In [None]:
# %%
import torch 
import torch.nn as nn

from functools import reduce
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np

In [None]:


def oneminus(args):
    """Return the complement ``1 - args``.

    Works for anything that supports subtraction from an int (Python numbers,
    tensors, arrays); for tensors/arrays the operation is element-wise.
    """
    complement = 1 - args
    return complement


def calculate_loss(*Args):
    
    return torch.sqrt(torch.abs(torch.sum(reduce(torch.add,[torch.pow(arg,2).view(*list([1]*i+[arg.shape[0]]+[1]*(len(Args)-1-i)+[-1])) for i,arg in enumerate(Args)]).sub_(
                            torch.pow(reduce(torch.add,[arg.view(*list([1]*i+[arg.shape[0]]+[1]*(len(Args)-1-i)+[-1])) for i,arg in enumerate(Args)]),2),alpha=1/len(Args)),dim=-1)))

def lossfn(*args):
    """Return ``1 - calculate_loss(*args)``, turning the spread into a similarity-like score."""
    spread = calculate_loss(*args)
    return oneminus(spread)

In [None]:

class PTLModule(pl.LightningModule):
    """Toy LightningModule trained with an n-way contrastive objective.

    Random token ids are embedded, perturbed with Gaussian noise ``n`` times,
    encoded by two linear layers, and the ``n`` encoded views are scored with
    ``lossfn``.  The resulting n-dimensional logit tensor is trained with cross
    entropy against a target that is 1 only where all indices coincide.

    Args:
        batch_size: Batch size used by ``train_dataloader`` (also the attribute
            Lightning's batch-size tuner adjusts).
        learning_rate: Learning rate for AdamW.
        n: Number of noisy views created per batch.
    """

    def __init__(self,
                batch_size=16,
                learning_rate=0.00001,
                n=6,):
        super().__init__()
        self.save_hyperparameters()
        self.emb = nn.Embedding(1000, 32)
        self.layer1 = nn.Linear(32, 128)
        self.layer2 = nn.Linear(128, 64)
        # Learnable temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.calculate_loss = lossfn
        self.n = n
        # NOTE(review): this is a process-wide debugging switch that slows down
        # every backward pass; kept because the loss takes sqrt() of values
        # that can be exactly zero.  Consider removing once training is stable.
        torch.autograd.set_detect_anomaly(True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Encode embeddings of shape (B, 32) into features of shape (B, 64)."""
        x = self.layer1(x)
        x = self.layer2(x)
        return x

    def setup(self, stage):
        """Create the synthetic training set: 10000 random ids in [0, 1000)."""
        self.train_dataset = torch.utils.data.TensorDataset(torch.randint(0, 1000, (10000,)))

    def train_dataloader(self, batch_size=32):
        """Return the training DataLoader.

        The ``batch_size`` parameter is unused and kept only for backward
        compatibility; ``self.batch_size`` is what is actually applied.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=8,
            drop_last=True,
        )

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and return loss, labels and detached embeddings."""
        x = self.emb(batch[0])  # (B, 32)
        # Draw an independent noise sample and run a separate forward pass for
        # each view.  (`[view] * n` would repeat one and the same tensor.)
        nx = [self(x + torch.randn_like(x)) for _ in range(self.n)]  # n x (B, 64)

        logits = torch.mul(torch.nan_to_num(self.calculate_loss(*nx)), self.logit_scale.exp())
        # Build a target with the same rank as the logits that is 1 only where
        # all indices are equal (repeated diag_embed of a vector of ones).
        labels = torch.ones_like(batch[0], dtype=torch.float)
        while len(labels.shape) < len(logits.shape):
            labels = torch.diag_embed(labels)

        loss = self.loss(logits, labels)
        self.log("loss", loss, enable_graph=False)
        # Detach the embeddings: they are collected for the whole epoch and
        # must not keep every step's autograd graph alive.
        return {"loss": loss, "labels": batch[0], "embs": nx[0].detach()}

    def training_epoch_end(self, outputs):
        """Fit a logistic-regression probe on the epoch's embeddings and log its accuracy.

        The probe is fitted and scored on the same data, so "score" is a
        training accuracy.
        """
        # Local import keeps scikit-learn an optional dependency of the module.
        from sklearn.linear_model import LogisticRegression

        labels = torch.cat([out['labels'] for out in outputs], dim=0).detach().cpu().numpy()
        embs = torch.cat([out['embs'] for out in outputs], dim=0).detach().cpu().numpy()

        reg = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=0, n_jobs=-1)
        reg.fit(embs, labels)

        self.log("score", reg.score(embs, labels), on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """Return a single AdamW optimizer over all parameters."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.hparams.learning_rate, eps=10e-8,  # eps == 1e-7
        )
        return [optimizer]


from functools import reduce
class PTLModuleStock(PTLModule):
    """Baseline variant: sum of pairwise cross-entropy losses between all views.

    For every ordered pair of views a (B, B) similarity matrix is formed and
    trained with cross entropy so that row ``k`` selects column ``k``.
    """

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and return loss, labels and detached embeddings."""
        x = self.emb(batch[0])  # (B, 32)
        # Independent noise sample + forward pass per view (`[view] * n` would
        # repeat one and the same tensor).
        nx = [self(x + torch.randn_like(x)) for _ in range(self.n)]  # n x (B, 64)

        targets = torch.arange(batch[0].shape[0], device=self.device)
        scale = self.logit_scale.exp()
        # nn.CrossEntropyLoss takes only (input, target); the former `alpha=`
        # keyword (and the undefined `self.alpha`) made every step raise.
        loss = reduce(
            torch.add,
            [self.loss(item @ other.T * scale, targets) for other in nx for item in nx],
        )

        self.log("loss", loss, enable_graph=False)
        # Detach so that collecting outputs for the epoch does not retain graphs.
        return {"loss": loss, "labels": batch[0], "embs": nx[0].detach()}
    




In [None]:

# We're going to create some cool graphs: score per number of views n, for each method.
results={n:{ i:{} for i in range(17)} for n in range(2,14)}


def _fit_and_collect(model):
    """Tune the batch size, train `model` for 20 epochs and return its logged metrics."""
    trainer = Trainer(
        gpus=1,
        max_epochs=20,
        logger=TensorBoardLogger("tb_logs"),
        auto_scale_batch_size="binsearch",
        auto_lr_find=False,
    )
    trainer.tune(model)
    trainer.fit(model)
    # Store CPU copies so the pickle can be read on a machine without a GPU.
    return {
        key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
        for key, value in trainer.logged_metrics.items()
    }


i=5  # key under which the PTLModule runs are stored
for n in range(2,14):
    # PTLModule takes no `logitsversion` argument (it raised TypeError), and
    # the loop variable has to be passed on, otherwise every run uses n=6.
    model=PTLModule(n=n)
    results[n][i]=_fit_and_collect(model)


for n in range(2,14):
    # benchmark model
    model=PTLModuleStock(n=n)
    try:
        results[n]["stock"]=_fit_and_collect(model)
    except Exception as e:
        # Best effort: record the failure and carry on with the next n.
        print(f"stock run for n={n} failed: {e!r}")
        results[n]["stock"]=None

#save results
import pickle
with open("base-v-.pkl","wb") as f:
    pickle.dump(results,f)
    

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle


with open("base-v-.pkl","rb") as f:
    # NOTE: pickle can execute arbitrary code; only load files written by this notebook.
    base=pickle.load(f)


def _score_or_nan(metrics):
    """Return the "score" entry of a metrics dict as a float, or NaN if the run has none.

    Failed runs are stored as ``None`` and runs that never logged a score have
    no "score" key; both are plotted as gaps instead of crashing the cell.
    """
    if not metrics or "score" not in metrics:
        return float("nan")
    return float(metrics["score"])


# Take the x-axis from the stored keys so it always matches the scores.
ns=list(base)
stockscores=[_score_or_nan(base[n].get("stock")) for n in ns]
scores=[_score_or_nan(base[n].get(5)) for n in ns]

#plot n against score for both

plt.plot(ns,scores,label="v5")
plt.plot(ns,stockscores,label="stock")
plt.legend()
plt.xlabel("n")
plt.ylabel("score")
plt.title("Score for different n")
plt.savefig("scorev5.png")