In [1]:
import openml
import pandas as pd
import datetime
import os
import json
import transformers
from src.tabby import MOEModelForCausalLM
import torch

# Load real data

In [2]:
dataset = openml.datasets.get_dataset('diabetes')
df, _, _, _ = dataset.get_data(dataset_format="dataframe")

# rename and clean columns
cols = ['pregnancies', 'glucose-plasma', 'blood-pressure', 'skin-thickness', 'insulin', 'BMI', 'pedigree', 'age', 'diagnosis']
df.columns = cols 
df['diagnosis'] = df['diagnosis'].map(lambda x: 'positive' if x=='tested_positive' else 'negative')

# shuffle data
df = df.sample(frac=1, random_state=42, ignore_index=True)

# split into train/val/test sets used by paper
n = len(df)
train_size = int(0.75 * n)
val_size = int(0.075 * n)
train = df.iloc[:train_size, :]
val = df.iloc[train_size:train_size+val_size, :]
test = df.iloc[train_size+val_size:, :]
print('train', train.shape, 'val', val.shape, 'test', test.shape)

train (576, 9) val (57, 9) test (135, 9)


# Load model

In [3]:
# tokenizer 
tokenizer = transformers.AutoTokenizer.from_pretrained('distilgpt2', padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
special_tokens_dict = {"bos_token": "<BOS>", 'eos_token': '<EOC>'}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

# model
num_experts = train.shape[1]
model = transformers.AutoModelForCausalLM.from_pretrained('distilgpt2', device_map='cuda')
model.resize_token_embeddings(len(tokenizer))
model = MOEModelForCausalLM(model, num_experts=num_experts, multihead=True,
    pad=tokenizer.pad_token_id, eoc=len(tokenizer)-1)

# model checkpoint 
ckpt_path = '/mnt/data/sonia/ckpts/diabetes-new/plain/emh/1e-4/1/model.pt'
sd = torch.load(ckpt_path)
for name, param in model.named_parameters():
    param.data.copy_(sd[name])
    
# prepare model to perform synthesis
token_heads = list(range( len(train.columns) ))
column_names_tokens = tokenizer(list(train.columns), add_special_tokens=False).input_ids
model.set_generation_mode(token_heads=token_heads, column_names_tokens=column_names_tokens)

  sd = torch.load(ckpt_path)


# Perform Synthesis

In [4]:
n_samples = 5 

outputs = []
for _ in range(n_samples):
    inputs = torch.full((1, 1), tokenizer.bos_token_id).to(model.device)
    toks = model.generate(inputs, do_sample=True, num_beams=1, max_length=1000)[...,1:]
    outputs.append(tokenizer.batch_decode(toks)[0]) 
    
# parse the lines output by model
def parse_line(l):
    entries = l.split('<EOC>')[:-1] # remove newline at end
    words = [c.split(' ') for c in entries] #'name', 'is', 'value'
    d = dict()
    for c in words:
        if c[0] in cols and len(c) == 3 and c[0] not in d: # keep only first occurence
            d[c[0]] = c[2]
            
    if set(d.keys()) == set(cols):
        return d 
    else:
        return None

print('\n\n**Synthetic dataset:**')
dicts = [parse_line(out) for out in outputs]
pd.DataFrame(dicts)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask



**Synthetic dataset:**


Unnamed: 0,pregnancies,glucose-plasma,blood-pressure,skin-thickness,insulin,BMI,pedigree,age,diagnosis
0,1.0,109.0,46.0,19.0,78.0,28.5,0.219,22.0,negative
1,6.0,189.0,110.0,31.0,0.0,28.5,0.68,37.0,negative
2,4.0,95.0,70.0,32.0,0.0,32.1,0.612,24.0,negative
3,0.0,73.0,0.0,0.0,0.0,21.1,0.342,25.0,negative
4,5.0,124.0,74.0,0.0,0.0,34.0,0.22,38.0,positive
