-
Notifications
You must be signed in to change notification settings - Fork 0
/
transferLearning.py
executable file
·115 lines (93 loc) · 3.35 KB
/
transferLearning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.optimizer import Optimizer
from torchmetrics import Accuracy
from BEATs.BEATs import BEATs, BEATsConfig
class BEATsTransferLearningModel(pl.LightningModule):
    """Transfer-learning classifier built on a pre-trained BEATs encoder.

    Loads a BEATs checkpoint, re-configures its head for
    ``num_target_classes`` classes, and fine-tunes encoder + a fresh linear
    classifier with AdamW. Batches are expected as
    ``(waveform, padding_mask, label)`` triples.
    """

    def __init__(
        self,
        num_target_classes: int = 50,
        milestones: int = 5,
        batch_size: int = 32,
        lr: float = 1e-3,
        lr_scheduler_gamma: float = 1e-1,
        num_workers: int = 6,
        model_path: str = "/checkpoints/BEATs_iter3_plus_AS2M.pt",
        **kwargs,
    ) -> None:
        """TransferLearningModel.

        Args:
            num_target_classes: Number of output classes for the new head.
            milestones: Stored for an LR schedule (see NOTE in
                ``configure_optimizers`` — no scheduler is configured yet).
            batch_size: Stored for dataloader construction elsewhere.
            lr: Initial learning rate.
            lr_scheduler_gamma: Stored LR decay factor (currently unused).
            num_workers: Stored dataloader worker count.
            model_path: Path to the pre-trained BEATs checkpoint (.pt).
        """
        super().__init__()
        self.lr = lr
        self.lr_scheduler_gamma = lr_scheduler_gamma
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.milestones = milestones
        self.num_target_classes = num_target_classes

        # Initialise BEATs model.
        # map_location="cpu" so a GPU-saved checkpoint also loads on
        # CPU-only machines; Lightning moves the module to device later.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from a trusted source.
        self.checkpoint = torch.load(model_path, map_location="cpu")
        self.cfg = BEATsConfig(
            {
                **self.checkpoint["cfg"],
                "predictor_class": self.num_target_classes,
                "finetuned_model": False,
            }
        )

        self._build_model()

        # Build the criterion once instead of per loss() call.
        self.loss_func = nn.CrossEntropyLoss()

        self.train_acc = Accuracy(
            task="multiclass", num_classes=self.num_target_classes
        )
        self.valid_acc = Accuracy(
            task="multiclass", num_classes=self.num_target_classes
        )
        self.save_hyperparameters()

    def _build_model(self):
        """Instantiate the pre-trained encoder and a fresh classifier head."""
        # 1. Load the pre-trained network
        self.beats = BEATs(self.cfg)
        self.beats.load_state_dict(self.checkpoint["model"])

        # 2. Classifier
        self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.cfg.predictor_class)

    def forward(self, x, padding_mask=None):
        """Return per-class logits of shape (batch, num_target_classes)."""
        # Get the representation (use `is not None`: `!=` on a tensor would
        # attempt an elementwise comparison rather than a null check)
        if padding_mask is not None:
            x, _ = self.beats.extract_features(x, padding_mask)
        else:
            x, _ = self.beats.extract_features(x)

        # Get the logits
        x = self.fc(x)

        # Mean pool over the sequence (time) dimension
        x = x.mean(dim=1)
        return x

    def loss(self, lprobs, labels):
        """Cross-entropy between logits `lprobs` and integer class `labels`."""
        return self.loss_func(lprobs, labels)

    def training_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, padding_mask, y_true = batch
        y_probs = self.forward(x, padding_mask)

        # 2. Compute loss
        train_loss = self.loss(y_probs, y_true)
        # Log the loss so it is visible alongside val_loss.
        self.log("train_loss", train_loss, prog_bar=True)

        # 3. Compute accuracy:
        self.log("train_acc", self.train_acc(y_probs, y_true), prog_bar=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, padding_mask, y_true = batch
        # Pass the padding mask here too — the original dropped it, making
        # validation inconsistent with training (padded frames were scored).
        y_probs = self.forward(x, padding_mask)

        # 2. Compute loss
        self.log("val_loss", self.loss(y_probs, y_true), prog_bar=True)

        # 3. Compute accuracy:
        self.log("val_acc", self.valid_acc(y_probs, y_true), prog_bar=True)

    def configure_optimizers(self):
        """AdamW over encoder + classifier parameters.

        NOTE(review): `milestones` / `lr_scheduler_gamma` are stored and
        `MultiStepLR` is imported, but no scheduler is returned — confirm
        whether an LR schedule was intended here.
        """
        optimizer = optim.AdamW(
            [{"params": self.beats.parameters()}, {"params": self.fc.parameters()}],
            lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
        )
        return optimizer