In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchmetrics


class LinearWeightBlock(nn.Module):
    def __init__(self, dim_embedding = 512, n_head = 2) -> None:
        super().__init__()
        
        self.dim_embedding = dim_embedding
        
        self.linear = nn.Linear(dim_embedding, dim_embedding//16)
        
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim_embedding//16, nhead=n_head), num_layers=2
        )
        
        self.bilinear = nn.Bilinear(dim_embedding//16, dim_embedding//16, dim_embedding//16)
        
        self.dropout = nn.Dropout(0.1)
        
        self.gelu = nn.GELU()
        
        self.layer_norm = nn.LayerNorm(dim_embedding//16)
        
    def forward(self, x):
        x = self.linear(x)
        x = self.layer_norm(x)
        x_0 = self.transformer_encoder(x)
        x = self.bilinear(x_0 + self.dropout(x_0), x)
        x = self.gelu(x)
        return x
    
    
class Classifier(nn.Module):
    def __init__(self, dim_embedding = 512, n_head = 2, n_class = 10) -> None:
        super().__init__()
        
        self.dim_embedding = dim_embedding
        self.n_class = n_class
        
        self.linear_weight_block = LinearWeightBlock(dim_embedding, n_head)
        
        self.mlp = nn.Sequential(
            nn.Linear(dim_embedding//16, dim_embedding//32),
            nn.GELU(),
            nn.LayerNorm(dim_embedding//32),
            nn.Linear(dim_embedding//32, n_class)
        )
        
    def forward(self, x):
        x = self.linear_weight_block(x)
        x = self.mlp(x)
        return x
    
    
class LitClassifier(pl.LightningModule):
    def __init__(self, dim_embedding = 512, n_head = 2, n_class = 10) -> None:
        super().__init__()
        
        self.model = Classifier(dim_embedding, n_head, n_class)
        
        self.loss = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=n_class)
        self.f1 = torchmetrics.F1Score(task="multiclass",num_classes=n_class)
        self.precision = torchmetrics.Precision(task="multiclass",num_classes=n_class)
        self.recall = torchmetrics.Recall(task="multiclass",num_classes=n_class)

        
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        
        self.log("train_loss", loss)
        self.log("train_acc", self.accuracy(y_hat, y))
        self.log("train_f1", self.f1(y_hat, y))
        self.log("train_precision", self.precision(y_hat, y))
        self.log("train_recall", self.recall(y_hat, y))
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        
        self.log("val_loss", loss)
        self.log("val_acc", self.accuracy(y_hat, y))
        self.log("val_f1", self.f1(y_hat, y))
        self.log("val_precision", self.precision(y_hat, y))
        self.log("val_recall", self.recall(y_hat, y))
        
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]
    
    
    

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, TensorDataset, random_split

# X = torch.randn(1000, 1280, 512)
# y = torch.randint(0, 10, (1000,))


# dataset = TensorDataset(X, y)

# KFold = KFold(n_splits=5, shuffle=True, random_state=42)

# for train_idx, val_idx in KFold.split(dataset):

#     train_dataset = torch.utils.data.Subset(dataset, train_idx)
#     val_dataset = torch.utils.data.Subset(dataset, val_idx)

#     train_loader = DataLoader(train_dataset, batch_size=32, num_workers=4)
#     val_loader = DataLoader(val_dataset, batch_size=32, num_workers=4)

#     model = LitClassifier()
    
#     from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
#     early_stopping = EarlyStopping('val_loss', patience=3, mode='min', verbose=True)
#     model_checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', verbose=True)

#     trainer = pl.Trainer(max_epochs=10, accellerator='auto', callbacks=[early_stopping, model_checkpoint])
#     trainer.fit(model, train_loader, val_loader)

#     trainer.test(model, val_loader)

#     break


X = torch.randn(1000, 512)
y = torch.randint(0, 10, (1000,))


dataset = TensorDataset(X, y)

train_dataset, val_dataset = random_split(dataset, [800, 200])

train_loader = DataLoader(train_dataset, batch_size=32, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, num_workers=4)

model = LitClassifier()

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
early_stopping = EarlyStopping('val_loss', patience=3, mode='min', verbose=True)
model_checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', verbose=True)

trainer = pl.Trainer(max_epochs=10, callbacks=[early_stopping, model_checkpoint])
trainer.fit(model, train_loader, val_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type                | Params
--------------------------------------------------
0 | model     | Classifier          | 325 K 
1 | loss      | CrossEntropyLoss    | 0     
2 | accuracy  | MulticlassAccuracy  | 0     
3 | f1        | MulticlassF1Score   | 0     
4 | precision | MulticlassPrecision | 0     
5 | recall    | MulticlassRecall    | 0     
--------------------------------------------------
325 K     Trainable params
0         Non-trainable params
325 K     Total params
1.300     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

NotImplementedError: The operator 'aten::_unique2' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.