In [None]:
from fastai.tabular.all import * 
from fastai.test_utils import show_install

# Record the software/hardware environment in the notebook output
# (fastai helper imported above; presumably prints library and CUDA versions).
show_install()

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
def set_seed_value(seed=718):
    """Seed every random number generator the notebook relies on.

    Seeds Python's `random`, NumPy's legacy global RNG and PyTorch (CPU and all
    CUDA devices). It also configures cuDNN for deterministic kernel selection.

    Args:
        seed: integer seed shared by all generators (default 718).
    """
    import random  # local import so the function does not depend on the star import
    # Only affects child processes; the current interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not just the current one
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different algorithms per run, defeating `deterministic`.
    torch.backends.cudnn.benchmark = False

set_seed_value()

In [None]:
# Competition data folder. Path.BASE_PATH is a fastai setting, presumably used to show paths relative to it.
path = Path("../input/tabular-playground-series-feb-2022")
Path.BASE_PATH = path
path.ls()

In [None]:
# Train/test frames are indexed by `row_id`; the sample submission keeps it as a column.
train_df = pd.read_csv(path / 'train.csv').set_index("row_id")
test_df = pd.read_csv(path / 'test.csv').set_index("row_id")
sample_submission = pd.read_csv(path / 'sample_submission.csv')

In [None]:
dep_var = "target"  # column the downstream classifier predicts

In [None]:
# Optional quick experiment: keep only two classes.
# train_df = train_df[train_df[dep_var].isin(['Enterococcus_hirae', 'Escherichia_coli'])]
train_df.shape[0]

In [None]:
# Split feature columns into continuous / categorical lists; `dep_var` is passed so the target is excluded.
cont_vars, cat_vars = cont_cat_split(train_df, dep_var=dep_var)
(len(cat_vars), len(cont_vars))

In [None]:
class ReadTabBatchIdentity(ItemTransform):
    "Read a batch of data and return the inputs as both `x` and `y`"
    def __init__(self, to): store_attr()

    def encodes(self, to):
        # Build a fresh tuple of input tensors on each call so `x` and `y` never share
        # the same tensor objects: (cats,) or (cats, conts).
        def _input_tensors():
            cats = tensor(to.cats).long()
            if not to.with_cont:
                return (cats,)
            return (cats, tensor(to.conts).float())

        res = _input_tensors() + _input_tensors()
        return res if to.device is None else to_device(res, to.device)
    
# Thin TabularPandas subclass. A later cell sets its `_dl_type` to TabDataLoaderIdentity,
# presumably so `.dataloaders()` builds autoencoder-style (inputs == targets) loaders.
class TabularPandasIdentity(TabularPandas): pass

In [None]:
@delegates()
class TabDataLoaderIdentity(TabDataLoader):
    "A transformed `DataLoader` for AutoEncoder problems with Tabular data"
    do_item = noops
    def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):
        if after_batch is None:
            # Default pipeline: the generic batch transforms, then the identity reader
            # that emits the inputs twice (once as x, once as y).
            default_tfms = L(TransformBlock().batch_tfms)
            after_batch = default_tfms + ReadTabBatchIdentity(dataset)
        super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch,
                         num_workers=num_workers, **kwargs)

    def create_batch(self, b):
        "Fetch all rows whose positions are in `b` with a single `iloc` lookup."
        rows = self.dataset.iloc[b]
        return rows

In [None]:
# Point TabularPandasIdentity at the identity DataLoader defined above; fastai reads
# `_dl_type` when building dataloaders from a TabularPandas object.
TabularPandasIdentity._dl_type = TabDataLoaderIdentity

In [None]:
to = TabularPandasIdentity(
    train_df,
    [Categorify, FillMissing, Normalize],
    cat_vars,
    cont_vars,
    device=device,
    splits=RandomSplitter(seed=32)(train_df),
)
dls = to.dataloaders(bs=1024)
# Batches are (cats, conts, cats, conts) when continuous columns exist:
# the first two elements are model inputs, the last two are reconstruction targets.
dls.n_inp = 2

(len(dls.train), len(dls.valid))

In [None]:
# Peek at one training batch; element 1 holds the (Normalize-d) continuous inputs.
batch = dls.one_batch()
(batch[1].min(), batch[1].max())

In [None]:
# Per-column statistics fitted by the Normalize proc, as one-row frames.
means = pd.DataFrame({k: [v] for k, v in dls.train_ds.means.items()})
stds = pd.DataFrame({k: [v] for k, v in dls.train_ds.stds.items()})

# Normalized min / max of each continuous column, shape (1, n_cont); passed to the
# model as `low` / `high` further below.
low = (train_df[cont_vars].min().to_numpy()[None, :] - means.to_numpy()) / stds.to_numpy()
high = (train_df[cont_vars].max().to_numpy()[None, :] - means.to_numpy()) / stds.to_numpy()

In [None]:
# Vocabulary size of each categorified column, i.e. how many logits the decoder needs per column.
total_cats = {col: len(vocab) for col, vocab in to.classes.items()}
total_cats

In [None]:
class RecreatedLoss(Module):
    "Measures how well we have created the original tabular inputs"
    def __init__(self, cat_dict, reduction='mean'):
        ce = CrossEntropyLossFlat(reduction='none')
        mse = MSELossFlat(reduction='none')
        store_attr('cat_dict,ce,mse,reduction')

    def forward(self, preds, cat_targs, cont_targs):
        """Per-row reconstruction loss: mean CE over categorical columns + mean MSE over continuous ones.

        Args:
            preds: `(cats, conts)` decoder outputs. `cats` holds the logits of every
                categorical column concatenated along dim 1, in `cat_dict` order.
            cat_targs: integer targets, one column per categorical variable.
            cont_targs: continuous targets, reshaped to `conts.shape` for the MSE.
        Returns:
            A scalar for reduction 'mean' / 'sum'; otherwise the per-row loss tensor.
        """
        cats, conts = preds
        ce_terms, pos = [], 0
        for i, n_classes in enumerate(self.cat_dict.values()):
            # Column i owns the next `n_classes` logits.
            ce_terms.append(self.ce(cats[:, pos:pos + n_classes], cat_targs[:, i]))
            pos += n_classes

        # Zeros (matching conts' device/dtype) so rows still get a CE term when there
        # are no categorical columns.
        tot_ce = conts.new_zeros(conts.shape[0])
        if ce_terms:
            tot_ce = tot_ce + torch.stack(ce_terms, axis=1).mean(axis=1)
        cont_loss = self.mse(conts, cont_targs).view(conts.shape).mean(axis=1)

        total_loss = tot_ce + cont_loss

        if self.reduction == 'mean':
            return total_loss.mean()
        elif self.reduction == 'sum':
            return total_loss.sum()

        return total_loss

loss_func = RecreatedLoss(total_cats)

In [None]:
class VAERecreatedLoss(Module):
    "Measures how well we have created the original tabular inputs, plus the KL Divergence with the unit normal distribution"
    def __init__(self, cat_dict, dataset_size, bs=1024):
        # NOTE(review): `cat_dict` is accepted but not stored, and `ce` is stored but never
        # used in `forward`, so categorical reconstruction does not contribute to this loss.
        ce = CrossEntropyLossFlat(reduction='sum')
        mse = MSELossFlat()
        store_attr('ce,mse,dataset_size,bs')

    def forward(self, preds, cat_targs, cont_targs):
        """Return (optionally weighted) KL divergence + scaled continuous MSE.

        `preds` is `(cats, conts, mu, logvar)`, or the same with a trailing `kl_weight`
        when AnnealedLossCallback has appended the scheduled weight.
        """
        if len(preds) == 5:
            cats, conts, mu, logvar, kl_weight = preds
        else:
            (cats, conts, mu, logvar), kl_weight = preds, 1

        # Mean MSE divided by (batch size + 1), then rescaled by dataset_size / bs.
        r_loss = self.mse(conts, cont_targs) / (1 + len(conts)) * (self.dataset_size / self.bs)

        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent dims.
        kl_sum = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return kl_weight * torch.mean(kl_sum) + r_loss

vae_loss_func = VAERecreatedLoss(total_cats, train_df.shape[0], bs=dls.bs)

In [None]:
class BatchSwapNoise(Module):
    "Swap Noise Module"
    def __init__(self, p): store_attr()

    def forward(self, x):
        """Swap-noise augmentation for a 2-D batch `x` (rows = samples).

        In training mode, each entry is replaced with probability `p` by the value from
        the same column of another randomly chosen row. `x` is returned unchanged in
        eval mode, when it is None, or when it is empty.
        """
        if not self.training or x is None or x.nelement() == 0:
            return x
        n_rows, n_cols = x.shape
        numel = x.nelement()
        # Draw all randomness on x's device, avoiding host<->device copies.
        swap = torch.rand(x.shape, device=x.device) > (1 - self.p)
        row_shift = torch.randint(0, n_rows, x.shape, device=x.device)
        # Adding k * n_cols to a flat index moves k rows down in the same column;
        # the modulo wraps indices that run past the end back to the top.
        offset = (row_shift * swap * n_cols).view(-1)
        idx = (torch.arange(numel, device=x.device) + offset) % numel
        return x.flatten()[idx].view(x.shape)

In [None]:
class TabularAE(TabularModel):
    "A simple AutoEncoder model"
    def __init__(self, emb_szs, n_cont, hidden_size, cats, low, high, ps=0.2, embed_p=0.01, bswap=None, act_cls=Mish()):
        super().__init__(emb_szs, n_cont, out_sz=hidden_size, layers=[1024, 512, 256], embed_p=embed_p, act_cls=act_cls)
        # NOTE(review): `ps` is not forwarded to TabularModel, so the encoder trunk uses
        # TabularModel's default dropout; only the decoders below use `ps`.

        self.bswap = bswap
        self.cats = cats
        # Total number of logits needed to reconstruct every categorical column.
        self.activation_cats = sum(cats.values())

        # Replace TabularModel's final output layer with one that has dropout and the
        # configured activation (`act_cls`, consistent with every other layer).
        self.layers = nn.Sequential(*L(self.layers.children())[:-1] + nn.Sequential(LinBnDrop(256, hidden_size, p=ps, act=act_cls)))

        if bswap is not None: self.noise = BatchSwapNoise(bswap)
        self.decoder = nn.Sequential(
            LinBnDrop(hidden_size, 256, p=ps, act=act_cls),
            LinBnDrop(256, 512, p=ps, act=act_cls),
            LinBnDrop(512, 1024, p=ps, act=act_cls),
        )

        # Continuous head, squashed into [low, high] by SigmoidRange.
        self.decoder_cont = nn.Sequential(
            LinBnDrop(1024, 512, p=ps, act=act_cls),
            LinBnDrop(512, 128, p=ps, act=act_cls),
            LinBnDrop(128, n_cont, p=ps, bn=False, act=None),
            SigmoidRange(low=low, high=high)
        )

        # Categorical head: concatenated logits for all categorical columns.
        self.decoder_cat = nn.Sequential(
            LinBnDrop(1024, 512, p=ps, act=act_cls),
            LinBnDrop(512, 128, p=ps, act=act_cls),
            LinBnDrop(128, self.activation_cats, p=ps, bn=False, act=None)
        )

    def forward(self, x_cat, x_cont=None, do_encode=False):
        """Encode a batch; return the latent code if `do_encode`, else `(cat_logits, conts)`."""
        if self.bswap is not None:
            x_cat = self.noise(x_cat)
            if x_cont is not None:
                x_cont = self.noise(x_cont)

        encoded = super().forward(x_cat, x_cont)
        if do_encode:
            return encoded  # the compressed representation
        decoded_trunk = self.decoder(encoded)
        decoded_cats = self.decoder_cat(decoded_trunk)
        decoded_conts = self.decoder_cont(decoded_trunk)
        return decoded_cats, decoded_conts

In [None]:
class TabularVAE(TabularModel):
    "Tabular VAE: TabularModel encoder -> (mu, logvar) heads -> sampled z -> decoder heads"
    def __init__(self, emb_szs, n_cont, hidden_size, cats, low, high, ps=0.2, embed_p=0.01, bswap=None, act_cls=Mish()):
        super().__init__(emb_szs, n_cont, layers=[300,200,100], out_sz=2*hidden_size, embed_p=embed_p, act_cls=act_cls)

        self.bswap = bswap
        self.cats = cats
        # Total number of logits needed to reconstruct every categorical column.
        self.activation_cats = sum(cats.values())

        # Heads producing the posterior log-variance and mean of the latent code.
        self.logVarLayer = LinBnDrop(2*hidden_size, hidden_size, p=ps)
        self.muLayer = LinBnDrop(2*hidden_size, hidden_size, p=ps)

        if self.bswap is not None:
            self.noise = BatchSwapNoise(self.bswap)

        # Shared decoder trunk (built once; an earlier duplicate definition was dead code).
        self.decoder = nn.Sequential(
            LinBnDrop(hidden_size, 256, p=ps, act=act_cls),
            LinBnDrop(256, 512, p=ps, act=act_cls),
            LinBnDrop(512, 1024, p=ps, act=act_cls)
        )

        # Continuous head squashed into [low, high], the bounds supplied by the caller.
        self.decoder_cont = nn.Sequential(
            LinBnDrop(1024, n_cont, p=ps, bn=False, act=None),
            SigmoidRange(low=low, high=high)
        )

        # Categorical head: concatenated logits for all categorical columns.
        self.decoder_cat = LinBnDrop(1024, self.activation_cats, p=ps, bn=False, act=None)

    @staticmethod
    def reparameterize(mu, logvar):
        "Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."
        std = torch.exp(logvar/2)
        eps = torch.randn_like(std)  # same device/dtype as `std`
        return mu + eps * std

    def bottleneck(self, x):
        "Map encoder output to `(z, mu, logvar)`."
        mu = self.muLayer(x)
        logvar = self.logVarLayer(x)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def forward(self, x_cat, x_cont=None, do_encode=False):
        """Return sampled `z` if `do_encode`, else `(cat_logits, conts, mu, logvar)`."""
        if self.bswap is not None:
            x_cat = self.noise(x_cat)
            if x_cont is not None:
                x_cont = self.noise(x_cont)

        encoded = super().forward(x_cat, x_cont)
        z, mu, logvar = self.bottleneck(encoded)

        if do_encode:
            # NOTE(review): this is a *sampled* z, so encodings are stochastic even in
            # eval mode; returning `mu` would give deterministic features — confirm intent.
            return z

        decoded_trunk = self.decoder(z)
        decoded_cats = self.decoder_cat(decoded_trunk)
        decoded_conts = self.decoder_cont(decoded_trunk)
        return decoded_cats, decoded_conts, mu, logvar

In [None]:
class MSEMetric(Metric):
    "Batch-averaged continuous reconstruction MSE, divided by (batch size + 1)"
    def __init__(self):
        self.preds = []
        mse = MSELossFlat()
        store_attr('mse')

    def reset(self):
        self.preds.clear()

    def accumulate(self, learn):
        _, conts, _, _ = learn.pred
        _, cont_targs = learn.y
        batch_scale = 1 + len(conts)
        self.preds.append(to_detach(self.mse(conts, cont_targs) / batch_scale))

    @property
    def value(self):
        return np.mean(np.array(self.preds))
    
class CEMetric(Metric):
    "Summed categorical cross-entropy per batch, divided by the number of categorical columns"
    def __init__(self): self.preds = []

    def reset(self):
        # Clear per-epoch state; without this, values from every previous epoch kept accumulating.
        self.preds.clear()

    def accumulate(self, learn):
        cats, conts, mu, logvar = learn.pred
        cat_targs, cont_targs = learn.y
        CE = cats.new([0])
        pos = 0
        for i, (k, v) in enumerate(total_cats.items()):
            # Column i owns the next `v` logits of the concatenated categorical output.
            CE += F.cross_entropy(cats[:, pos:pos+v], cat_targs[:, i], reduction='sum')
            pos += v

        # Avoid 0/0 = nan when there are no categorical columns.
        norm = cats.new([max(len(total_cats), 1)])
        self.preds.append(to_detach(CE/norm))

    @property
    def value(self):
        return np.array(self.preds).mean()
    
class KldMetric(Metric):
    "Batch-averaged, unweighted KL divergence between the posterior and N(0, I)"
    def __init__(self):
        self.preds = []

    def reset(self):
        self.preds.clear()

    def accumulate(self, learn):
        _, _, mu, logvar = learn.pred
        # Same closed-form expression as the VAE loss, without the annealing weight.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        self.preds.append(to_detach(torch.mean(kld)))

    @property
    def value(self):
        return np.mean(np.array(self.preds))
    
class MuMetric(Metric):
    "Batch-averaged mean of the posterior means `mu`"
    def __init__(self):
        self.preds = []

    def reset(self):
        self.preds.clear()

    def accumulate(self, learn):
        mu = learn.pred[2] if len(learn.pred) == 4 else learn.pred[2:3][0]
        self.preds.append(to_detach(mu.mean()))

    @property
    def value(self):
        return np.mean(np.array(self.preds))
    
class StdMetric(Metric):
    "Batch-averaged mean of the posterior log-variances"
    # NOTE(review): despite the name, this reports mean(logvar), not a standard deviation.
    def __init__(self):
        self.preds = []

    def reset(self):
        self.preds.clear()

    def accumulate(self, learn):
        _, _, _, logvar = learn.pred
        batch_mean = logvar.mean()
        self.preds.append(to_detach(batch_mean))

    @property
    def value(self):
        return np.mean(np.array(self.preds))
    
class CeMetric(Metric):
    "Summed categorical cross-entropy per batch, divided by (number of categorical columns + 1)"
    # NOTE(review): near-duplicate of CEMetric (only the normalizer differs); consider keeping one.
    def __init__(self): self.preds = []

    def reset(self):
        # Clear per-epoch state; without this, values from every previous epoch kept accumulating.
        self.preds.clear()

    def accumulate(self, learn):
        cats, conts, mu, logvar = learn.pred
        cat_targs, cont_targs = learn.y
        CE = cats.new([0])
        pos = 0
        for i, (k, v) in enumerate(total_cats.items()):
            # Column i owns the next `v` logits of the concatenated categorical output.
            CE += F.cross_entropy(cats[:, pos:pos+v], cat_targs[:, i], reduction='sum')
            pos += v

        norm = cats.new([1+len(total_cats.keys())])
        self.preds.append(to_detach(CE/norm))

    @property
    def value(self):
        return np.array(self.preds).mean()

In [None]:
emb_szs = get_emb_sz(to.train)
latent_space_size = 8  # dimensionality of the latent code z
# Decoder output bounds, moved to `device` to match the model's outputs.
low_values, high_values = tensor(low).to(device), tensor(high).to(device)

modelVAE = TabularVAE(
    emb_szs, len(cont_vars), latent_space_size,
    cats=total_cats, low=low_values, high=high_values,
    ps=0.0, embed_p=0.0, bswap=None,
)

In [None]:
class AnnealedLossCallback(Callback):
    """Expose the scheduled `kl_weight` hyper-parameter to the loss function.

    After the forward pass, the current weight is appended to `learn.pred` so
    VAERecreatedLoss receives a 5-tuple. After the batch, the weight is removed again
    so metrics and prediction gathering see the usual `(cats, conts, mu, logvar)`.
    """
    def after_pred(self):
        # Append only once, and only to the 4-tuple model output.
        if len(self.learn.pred) == 4:
            kl = self.learn.pred[0].new_empty(1)  # same device/dtype as the predictions
            kl[0] = self.opt.hypers[0]['kl_weight']
            self.learn.pred = self.learn.pred + (kl,)

    def after_batch(self):
        # `len(pred)` alone is always truthy; check the length explicitly so a
        # 4-tuple pred is left untouched instead of failing to unpack.
        if len(self.learn.pred) == 5:
            self.learn.pred = tuple(self.learn.pred[:4])

In [None]:
# KL annealing schedule over training progress: 0 for the first 10%, cosine ramp 0 -> 1
# over the next 30%, then held at 1 for the remaining 60%.
f = combine_scheds([.1, .3, .6], [SchedCos(0, 0), SchedCos(0, 1), SchedNo(1, 1)])
callbacks = [ParamScheduler({'kl_weight': f}), AnnealedLossCallback()]

learn = Learner(
    dls, modelVAE,
    loss_func=vae_loss_func,
    cbs=callbacks,
    wd=0.01,
    metrics=[MSEMetric(), KldMetric(), MuMetric(), StdMetric()],
    opt_func=ranger,
)

In [None]:
# learn.lr_find()

In [None]:
# Alternative schedule kept for reference:
# learn.fit_one_cycle(200, 5e-3, cbs=SaveModelCallback(fname='tab_vae1', with_opt=True))
# Train the VAE for 200 epochs; SaveModelCallback checkpoints to 'tab_vae1' (with optimizer
# state), by default whenever the monitored validation loss improves.
learn.fit_flat_cos(200, lr=5e-5, cbs=SaveModelCallback(fname='tab_vae1', with_opt=True)) 

In [None]:
#learn.load('tab_vae1')

In [None]:
def get_compressed_representation(l, df):
    """Encode every row of `df` into the model's latent space.

    Args:
        l: trained Learner whose model accepts `do_encode` as its third positional
           argument (as TabularAE / TabularVAE do).
        df: DataFrame with the same feature columns the learner was trained on.
    Returns:
        np.ndarray with one row of latent features per row of `df`, in `df` order
        (assumes `test_dl` does not shuffle — fastai default).
    """
    dl = l.dls.test_dl(df)
    l.model.eval()
    # Pick the device the same way as the notebook's `device` instead of forcing CUDA,
    # so this also runs on CPU-only machines.
    l.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    with torch.no_grad():
        # batch[:2] are the (cats, conts) inputs; True selects encode-only mode.
        comp_reps = [l.model(*batch[:2], True).cpu().numpy() for batch in dl]
    return np.concatenate(comp_reps)

In [None]:
# Loader over all training rows (fastai test_dl does not shuffle by default), reused by get_preds below.
dl = learn.dls.test_dl(train_df)

In [None]:
# Latent representation of every training row, via the helper defined above
# (avoids duplicating its loop and leaking `batch` / `out` into the global namespace).
outs = get_compressed_representation(learn, train_df)

outs.shape

In [None]:
# Full VAE outputs (decoded cats/conts, mu, logvar) and reconstruction targets for every training row.
(cat_preds, cont_preds, mu, logvar), (cat_targs, cont_targs) = learn.get_preds(dl=dl)

In [None]:
# Class labels of the training rows; should line up row for row with `outs`.
ys = train_df[dep_var].to_numpy()
(len(outs), len(ys))

In [None]:
# One column per latent dimension, named '000', '001', ...
para_col_names = [f'{i:03d}' for i in range(latent_space_size)]
# Target column followed by the float32 latent features, indexed like train_df.
encoded_train_df = pd.concat(
    [train_df[[dep_var]],
     pd.DataFrame(outs, index=train_df.index, columns=para_col_names).astype(np.float32)],
    axis=1,
)

encoded_train_df.describe()

In [None]:
import seaborn as sns

# 2-D view of the first two latent dimensions, coloured by target class.
x_col, y_col = '000', '001'
df_x = encoded_train_df[[x_col, y_col, dep_var]].copy()

fig, ax = plt.subplots(figsize=(14, 7))
sns.scatterplot(data=df_x, x=x_col, y=y_col, hue=dep_var, alpha=0.8, palette="bright", ax=ax)
ax.set_title("AutoEncoder")
plt.show()

In [None]:
# new_names =  [str(x).zfill(3) for x in range(0,128)]
# Seeded (same seed as the VAE split) so the classifier's train/valid split is reproducible;
# an unseeded splitter would depend on whatever global RNG state training left behind.
splits = RandomSplitter(seed=32)(range_of(encoded_train_df))
to2 = TabularPandas(encoded_train_df,
                   procs=[Normalize],
                   cont_names=para_col_names,
                   splits=splits,
                   y_names=dep_var,
                   y_block=CategoryBlock())

In [None]:
dls2 = to2.dataloaders(bs=2048)
(len(dls2.train), len(dls2.valid))  # number of batches per split

In [None]:
def accuracy(inp, targ, axis=-1):
    "Compute accuracy with `targ` when `pred` is bs * n_classes"
    # NOTE(review): this likely shadows the `accuracy` metric from the fastai star import.
    predicted = inp.argmax(dim=axis)
    predicted, targ = flatten_check(predicted, targ)
    hits = (predicted == targ).float()
    return hits.mean()

In [None]:
# MLP classifier trained on the VAE's latent features.
learn2 = tabular_learner(
    dls2,
    layers=[1024, 1024, 512, 256, 128],
    metrics=[accuracy],
)
learn2.summary()

In [None]:
# Learning-rate range test, to help choose the rate passed to fit_one_cycle below.
learn2.lr_find()

In [None]:
# Train the classifier; SaveModelCallback checkpoints to 'kaggle_tps_feb_2022_vae' (with optimizer state).
learn2.fit_one_cycle(250, 7e-3, wd=0.01, cbs=SaveModelCallback(fname='kaggle_tps_feb_2022_vae', with_opt=True)) 

In [None]:
# Reload the checkpoint written under this name by SaveModelCallback during the fit above.
learn2.load('kaggle_tps_feb_2022_vae')

In [None]:
# Normalized confusion matrix of the classifier (fastai uses the validation split by default).
interp = ClassificationInterpretation.from_learner(learn2)
interp.plot_confusion_matrix(normalize=True, norm_dec=3, figsize=(10, 10))

In [None]:
# Encode the test rows with the trained VAE (same helper as for the training rows).
outs = get_compressed_representation(learn, test_df)
outs.shape

# Build the test feature frame from the TEST encodings. The previous version overwrote
# these columns with encoded_train_df, so predictions were made on training rows.
encoded_test_df = pd.DataFrame(outs, columns=para_col_names, index=test_df.index).astype(np.float32)
assert len(encoded_test_df) == len(test_df), "one encoding per test row expected"

encoded_test_df.head()

In [None]:
# Class probabilities for every encoded test row, predicted in large batches.
dlt = learn2.dls.test_dl(encoded_test_df, bs=4096)
preds, _ = learn2.get_preds(dl=dlt)
print(preds[:2])

In [None]:
# Map each test row's arg-max class index back to its label through the classifier's vocab,
# then write the submission (assumes sample_submission rows are in test_df order — TODO confirm).
decoded_preds_str = dls2.train.categorize.vocab.map_ids(np.argmax(preds, axis=1))
sample_submission[dep_var] =  decoded_preds_str
sample_submission.to_csv("submission.csv", index=False)
sample_submission.head(10)

In [None]:
# Compare the class distribution of the training labels with that of the test predictions.
dep_var_dist = (
    train_df[dep_var].value_counts()
    .pipe(lambda counts: pd.DataFrame({
        'target_count': counts,
        'target_quota (%)': counts / train_df.shape[0] * 100,
    }))
    .assign(pred_count=pd.Series(decoded_preds_str, index=test_df.index).value_counts())
    .assign(**{'pred_quota (%)': lambda d: d['pred_count'] / len(test_df) * 100})
)
dep_var_dist.sort_index().head(11)

In [None]:
# List the working directory to confirm the output files (e.g. submission.csv) were written.
!ls -la 