In [7]:
import torch as th 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
from torch.utils.data import Dataset, random_split, DataLoader

import lightning as L
from lightning import Trainer, LightningModule
from lightning.pytorch.callbacks import RichProgressBar, RichModelSummary
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme

import os 
import cv2 
import numpy as np 
from random import randint
import matplotlib.pyplot as plt
from yoloface import face_analysis

import warnings 
warnings.filterwarnings("ignore")

In [8]:
# Dataset locations. Set FACE_DATASET_DIR / ONESHOT_DATA_DIR to override the author's
# local defaults without editing the notebook. os.path.join(..., "") guarantees a
# trailing separator, which code that concatenates names directly onto these relies on.
PATH = os.path.join(os.environ.get("FACE_DATASET_DIR", r'/Users/suyashsachdeva/Desktop/Face Dataset/'), "")
SAVE = os.path.join(os.environ.get("ONESHOT_DATA_DIR", r'/Users/suyashsachdeva/Desktop/GyanBhandar/oneshot/'), "")

In [9]:
class Faces(Dataset):
    """Pre-sampled (anchor, positive, negative) face triplets held in memory.

    ``path`` should contain one sub-directory per identity, each holding that
    identity's face images. ``size`` triplets are sampled once at construction and
    stored as a float32 array of shape ``(size, 3, 3, imgshp, imgshp)``: triplet slot
    (anchor / positive / negative), then channel (RGB), height, width, with pixel
    values scaled to [0, 1].

    NOTE: earlier versions used ``reshape`` instead of a channel transpose (which
    scrambles pixels) and converted only the anchor to RGB, so weights trained on
    that layout will not transfer directly to data produced by this class.
    """

    def __init__(self, path, size=20000, imgshp=64, seed=None):
        self.data = self.datareader(path, size, imgshp, seed=seed)

    @staticmethod
    def _list_visible(directory):
        """Sorted entries of ``directory`` whose names do not start with '.' (skips e.g. ``.DS_Store``)."""
        return sorted(name for name in os.listdir(directory) if not name.startswith("."))

    @staticmethod
    def _read_bgr(filepath):
        """Load an image with OpenCV (BGR channel order), failing loudly on unreadable files."""
        image = cv2.imread(filepath)
        if image is None:  # cv2.imread reports failure by returning None, not by raising
            raise OSError(f"could not read image: {filepath!r}")
        return image

    @staticmethod
    def _to_chw(image_bgr, shape):
        """Convert an (H, W, 3) BGR image into a float32 (3, shape, shape) RGB array in [0, 1]."""
        if image_bgr.shape[:2] != (shape, shape):
            image_bgr = cv2.resize(image_bgr, (shape, shape), interpolation=cv2.INTER_AREA)
        rgb = image_bgr[..., ::-1]  # BGR -> RGB, same as cv2.COLOR_BGR2RGB for 3-channel input
        # transpose (not reshape) moves channels first while keeping every pixel where it was
        return rgb.transpose(2, 0, 1).astype(np.float32) / 255.0

    def datareader(self, path, size, shape, seed=None):
        """Sample ``size`` triplets from the identity folders under ``path``.

        Anchor and positive are two different images of the same identity whenever
        that identity has at least two images; the negative always comes from a
        different identity.

        Raises:
            ValueError: if fewer than two identity folders contain images.
            OSError: if an image file cannot be decoded.
        """
        rng = np.random.default_rng(seed)

        # List each identity folder once instead of re-listing it for every triplet.
        identities = {}
        for name in self._list_visible(path):
            folder = os.path.join(path, name)
            if os.path.isdir(folder):
                images = self._list_visible(folder)
                if images:
                    identities[name] = images
        names = sorted(identities)
        if len(names) < 2:
            raise ValueError(
                f"need at least two non-empty identity folders under {path!r}, found {len(names)}")

        # Allocate float32 directly (avoids a float64 temporary twice the size).
        train = np.zeros((size, 3, 3, shape, shape), dtype=np.float32)
        for idx in range(size):
            pos_i, neg_i = rng.choice(len(names), size=2, replace=False)
            pos_id, neg_id = names[pos_i], names[neg_i]
            pos_files = identities[pos_id]
            if len(pos_files) >= 2:
                a_i, p_i = rng.choice(len(pos_files), size=2, replace=False)
            else:
                a_i = p_i = 0
            neg_files = identities[neg_id]
            picks = (
                (pos_id, pos_files[a_i]),
                (pos_id, pos_files[p_i]),
                (neg_id, neg_files[rng.integers(len(neg_files))]),
            )
            for slot, (ident, fname) in enumerate(picks):
                image = self._read_bgr(os.path.join(path, ident, fname))
                train[idx, slot] = self._to_chw(image, shape)
        return train

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        triplet = th.from_numpy(self.data[idx])
        # "postive" (sic) is the key OneShot's steps read; the spelling is kept for compatibility.
        return {"anchor": triplet[0], "postive": triplet[1], "negative": triplet[2]}

In [10]:
class ConvBlock(nn.Module):
    """Stack of (zero-pad -> conv -> batch-norm -> LeakyReLU) layers with one skip connection.

    The first conv has stride 2 and maps ``infilter`` -> ``outfilter`` channels; the
    remaining ``n - 1`` convs keep ``outfilter`` channels at stride 1. Zero padding of
    ``(kernel - 1) // 2`` on every side is applied before every conv.

    Args:
        n: number of conv layers (values below 1 still build the first layer).
        infilter: input channel count.
        outfilter: output channel count.
        kernel: square kernel size.
        moment: BatchNorm ``momentum`` used by every norm layer.
        alpha: negative slope of the LeakyReLU.
    """

    def __init__(self, n: int, infilter: int, outfilter: int, kernel: int, moment: float, alpha: float):
        super(ConvBlock, self).__init__()
        self.conv = nn.ModuleList([nn.Conv2d(infilter, outfilter, kernel, stride=2)])
        # The first norm now uses `moment` like the others (it used to fall back to the default).
        self.norm = nn.ModuleList([nn.BatchNorm2d(outfilter, momentum=moment)])
        for _ in range(n - 1):
            self.conv.append(nn.Conv2d(outfilter, outfilter, kernel))
            self.norm.append(nn.BatchNorm2d(outfilter, momentum=moment))
        self.zpad = nn.ZeroPad2d(int((kernel - 1) // 2))
        self.relu = nn.LeakyReLU(alpha)

    def forward(self, x):
        first = None
        for conv, norm in zip(self.conv, self.norm):
            x = self.relu(norm(conv(self.zpad(x))))
            if first is None:
                first = x  # skip-connection source: output of the stride-2 layer
        # Residual add from the first layer's output. With n == 1 this returns 2 * x;
        # that is kept so existing checkpoints behave the same.
        return x + first


class CNN(nn.Module):
    """Embedding network: ConvBlocks -> global average pool -> dropout -> linear head.

    NOTE(review): ``CNN`` is defined a second time later in this same cell with
    ``filter=64``; that later definition replaces this one at import time, so this
    version is dead code and one of the two should be deleted.

    Args:
        num: layers per ConvBlock (any non-empty sequence of ints). The first block
            uses a 7x7 kernel on the 3-channel input; each later block multiplies
            the channel count by ``gf``.
        filter: channel count of the first block.
        kernel: kernel size of every block after the first.
        moment: BatchNorm momentum forwarded to each ConvBlock.
        alpha: LeakyReLU slope forwarded to each ConvBlock.
        dense: size of the output embedding.
        gf: channel growth factor between consecutive blocks.
        drop: dropout probability before the linear head.

    Raises:
        ValueError: if ``num`` is empty.
    """

    def __init__(self, num: tuple = (1, 3, 3, 2, 2), filter: int = 32, kernel: int = 3, moment: float = 0.7,
                 alpha: float = 0.03, dense: int = 128, gf: int = 2, drop: float = 0.2):
        super(CNN, self).__init__()
        if len(num) == 0:
            raise ValueError("num must contain at least one block size")
        self.convblock = nn.ModuleList([ConvBlock(num[0], 3, filter, 7, moment, alpha)])
        channels = filter
        for n in num[1:]:
            self.convblock.append(ConvBlock(n, channels, channels * gf, kernel, moment, alpha))
            channels *= gf
        self.dense = nn.Linear(channels, dense)
        self.drop = nn.Dropout(drop)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        # Not used by forward(); kept so the module's attributes stay unchanged.
        self.relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Map a batch of 3-channel images (N, 3, H, W) to (N, dense) embeddings."""
        for convblock in self.convblock:
            x = convblock(x)
        x = self.drop(self.flat(self.pool(x)))
        return self.dense(x)
    

class CNN(nn.Module):
    """Embedding network: ConvBlocks -> global average pool -> dropout -> linear head.

    Args:
        num: layers per ConvBlock (any non-empty sequence of ints). The first block
            uses a 7x7 kernel on the 3-channel input; each later block multiplies
            the channel count by ``gf``.
        filter: channel count of the first block.
        kernel: kernel size of every block after the first.
        moment: BatchNorm momentum forwarded to each ConvBlock.
        alpha: LeakyReLU slope forwarded to each ConvBlock.
        dense: size of the output embedding.
        gf: channel growth factor between consecutive blocks.
        drop: dropout probability before the linear head.

    Raises:
        ValueError: if ``num`` is empty.
    """

    def __init__(self, num: tuple = (1, 3, 3, 2, 2), filter: int = 64, kernel: int = 3, moment: float = 0.7,
                 alpha: float = 0.03, dense: int = 128, gf: int = 2, drop: float = 0.2):
        super(CNN, self).__init__()
        if len(num) == 0:
            raise ValueError("num must contain at least one block size")
        self.convblock = nn.ModuleList([ConvBlock(num[0], 3, filter, 7, moment, alpha)])
        channels = filter
        for n in num[1:]:
            self.convblock.append(ConvBlock(n, channels, channels * gf, kernel, moment, alpha))
            channels *= gf
        self.dense = nn.Linear(channels, dense)
        self.drop = nn.Dropout(drop)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        # Not used by forward(); kept so the module's attributes stay unchanged.
        self.relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Map a batch of 3-channel images (N, 3, H, W) to (N, dense) embeddings."""
        for convblock in self.convblock:
            x = convblock(x)
        x = self.drop(self.flat(self.pool(x)))
        return self.dense(x)

In [11]:
class OneShot(LightningModule):
    """Lightning wrapper that trains ``model`` as an embedder with a triplet margin loss.

    The dataset is split once into train/validation subsets. Batch items are dicts
    with "anchor", "postive" (sic) and "negative" image tensors; each is embedded by
    ``model`` and scored with ``TripletMarginWithDistanceLoss`` using
    ``nn.PairwiseDistance()`` and margin 1.0.

    Batch-size schedule: the first train dataloader uses exactly ``batchsize``; every
    later reload (e.g. via the Trainer's ``reload_dataloaders_every_n_epochs``) divides
    the current size by ``batch_decay``, truncates to an int, and never goes below 1.

    Args:
        dataset: dataset yielding the triplet dicts described above.
        model: embedding network.
        split: fraction of ``dataset`` used for training.
        batchsize: batch size of the first train dataloader.
        shape: image side length; inputs are reshaped to (-1, 3, shape, shape).
        learning_rate: initial Adam learning rate.
        decay_step: currently unused (the StepLR below uses step_size=1).
        batch_decay: divisor applied to the batch size on each train dataloader reload.
        split_seed: optional seed for a reproducible train/validation split.
    """

    def __init__(self, dataset: Dataset, model: nn.Module, split: float = 0.7, batchsize: int = 256,
                 shape: int = 64, learning_rate: float = 1e-3, decay_step: int = 10,
                 batch_decay: float = 1.2, split_seed=None):
        super(OneShot, self).__init__()
        n_train = int(len(dataset) * split)
        split_kwargs = {} if split_seed is None else {"generator": th.Generator().manual_seed(split_seed)}
        self.traindata, self.validata = random_split(dataset, [n_train, len(dataset) - n_train], **split_kwargs)
        self.model = model
        self.split = split
        # Current batch size. Previously stored as int(batchsize * 1.2) and divided back
        # on load, which truncated (e.g. 146 -> 175 -> 145); now it starts exact.
        self.sizebatch = batchsize
        self.batch_decay = batch_decay
        self._train_loads = 0
        self.loss = nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance(), margin=1.0)
        self.shp = shape
        self.learate = learning_rate
        self.decay_step = decay_step

    def _triplet_loss(self, batch):
        """Embed anchor, positive and negative (in that order) and return (loss, batch size)."""
        def embed(key):
            return self.model(batch[key].reshape(-1, 3, self.shp, self.shp))

        anchor = embed("anchor")
        positive = embed("postive")
        negative = embed("negative")
        return self.loss(anchor, positive, negative), anchor.shape[0]

    def training_step(self, batch, batch_idx):
        loss, n_items = self._triplet_loss(batch)
        cur_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log("lr", cur_lr, prog_bar=True)
        self.log("batch_size", self.sizebatch, prog_bar=True)
        # Weight epoch averages by the real batch length (the last batch may be smaller).
        self.log_dict({"train_loss": loss}, on_epoch=True, on_step=True, prog_bar=True, logger=False,
                      batch_size=n_items)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, n_items = self._triplet_loss(batch)
        self.log_dict({"valid_loss": loss}, on_step=True, on_epoch=True, prog_bar=True, logger=False,
                      batch_size=n_items)
        return loss

    def configure_optimizers(self):
        self.optimizer = optim.Adam(self.parameters(), lr=self.learate)
        # Multiplies the LR by 0.96 on every scheduler step.
        self.sch = optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.96)
        return {"optimizer": self.optimizer,
                "lr_scheduler": {"scheduler": self.sch, "monitor": "train_loss"}}

    def train_dataloader(self):
        # Shrink only on reloads after the first, so the first epochs use `batchsize` exactly.
        if self._train_loads > 0:
            self.sizebatch = max(1, int(self.sizebatch / self.batch_decay))
        self._train_loads += 1
        return DataLoader(self.traindata, batch_size=self.sizebatch, shuffle=True)

    def val_dataloader(self):
        # Order does not affect the averaged validation loss, so no shuffling.
        return DataLoader(self.validata, batch_size=self.sizebatch, shuffle=False)


In [12]:
# Build the embedding network, then sample the in-memory triplet dataset from SAVE
# (with Faces' defaults this is 20000 triplets of 64x64 images — the slow step).
model = CNN()
dataset = Faces(path=SAVE)

In [13]:
def callmeback():
    """Return ``[RichProgressBar, RichModelSummary]`` callbacks with a custom colour theme."""
    theme = RichProgressBarTheme(
        description="black",
        progress_bar="green_yellow",
        progress_bar_finished="green",
        batch_progress="black",
        time="blue",
        processing_speed="black",
        metrics="red",
    )
    return [RichProgressBar(theme=theme), RichModelSummary()]

In [14]:
# Preview of the training schedule, one row per 10-epoch dataloader reload:
# reload index, batch size (divided by 1.2 and truncated each time), and the
# cumulative LR multiplier assuming StepLR's gamma=0.96 is applied once per epoch.
x, l = 256 * 1.2, 1
for c in range(20):
    x, l = int(x / 1.2), l * (0.96 ** 10)
    print(c, x, l)

0 256 0.6648326359915008
1 213 0.44200243387940735
2 177 0.2938576432307054
3 147 0.19536615155531986
4 122 0.12988579352203833
5 101 0.08635231448510454
6 84 0.057409836863099105
7 70 0.03816793317353621
8 58 0.025375287622109523
9 48 0.016870319358849577
10 40 0.01121593888936241
11 33 0.0074567222169343965
12 27 0.004957472287340882
13 22 0.0032958893686476534
14 18 0.002191214816894383
15 15 0.0014567911227395263
16 12 0.0009685222822199372
17 10 0.0006439052219047851
18 8 0.0004280892060076505
19 6 0.0002846076752695749


In [15]:
# Load weights from an earlier run's checkpoint with an overridden learning rate and
# batch size, then continue training on the MPS device.
path = r"/Users/suyashsachdeva/Desktop/DeepHub/CNN/One_shot/lightning_logs/version_29/checkpoints/epoch=30-step=2116.ckpt"
model = CNN()
learn = OneShot.load_from_checkpoint(
    path,
    dataset=dataset,
    model=model,
    learning_rate=0.000294,
    batchsize=146,
)
trainer = Trainer(
    min_epochs=120,
    max_epochs=170,
    accelerator="mps",
    reload_dataloaders_every_n_epochs=10,
)
trainer.fit(learn)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                          | Params
--------------------------------------------------------
0 | model | CNN                           | 19.7 M
1 | loss  | TripletMarginWithDistanceLoss | 0     
--------------------------------------------------------
19.7 M    Trainable params
0         Non-trainable params
19.7 M    Total params
78.766    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [16]:
# Move the trained network to CPU (nn.Module.to modifies the module in place) and
# serialise the whole module object, not just its state_dict.
th.save(obj=model.to("cpu"), f="./newvision.pt")

Resumed from the checkpoint with batch size 146 and learning rate 0.000294.