In [None]:
# Enable the autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import timm
from EFF3D.MRI import *
from fastai.vision.all import *
from EFF3D.vit import MRIVisionTransformer
import torch.nn as nn
from sam import SAM

In [None]:
# Table describing the dataset; the DataBlock cell reads its 'name' column
# for the inputs and its 'label' column for the targets.
df = pd.read_csv('../nonorm/ADNCold.csv')
# path=Path('/home/staff/xin/Downloads/newMRI/ADtrain')

In [None]:
# Input block: each item is turned into an MriTensorImage via its `create` constructor.
mri_block = TransformBlock(type_tfms=partial(MriTensorImage.create))

# Inputs come from the 'name' column and targets from the 'label' column of
# the dataframe; 20% of the rows are held out for validation with a fixed seed.
db = DataBlock(
    blocks=(mri_block, CategoryBlock),
    get_x=ColReader('name'),
    get_y=ColReader('label'),
    splitter=RandomSplitter(valid_pct=0.2, seed=2),
)

# Batches of 12 items, loaded with 4 worker processes.
dls = db.dataloaders(source=df, bs=12, num_workers=4)

In [None]:
# Array loaded from disk and handed to the transformer as its `mask` argument
# (NOTE(review): presumably a patch ordering/selection — confirm in EFF3D.vit).
sort = np.load('./sort.npy')

model = MRIVisionTransformer(
    mask=sort,
    num_heads=4,
    depth=6,
    num_patches=100,
)

In [None]:
class TstCallback(Callback):
    """Mark the positional-embedding and token parameters as exempt from weight decay.

    At the start of every fit, each parameter whose name is exactly one of
    'pos_embed', 'cls_token' or 'dist_token' gets `do_wd = False` written into
    its optimizer state.
    NOTE(review): this relies on the optimizer's weight-decay step honouring a
    `do_wd` entry in the per-parameter state — confirm for the optimizer in use.
    """
    def before_fit(self):
        """Write `do_wd = False` into the optimizer state of the exempt parameters."""
        exempt_names = ('pos_embed', 'cls_token', 'dist_token')
        # Matching is on the full parameter name, so a parameter nested inside a
        # submodule (e.g. 'backbone.pos_embed') is not selected.
        selected = (param for pname, param in self.named_parameters() if pname in exempt_names)
        for param in selected:
            self.opt.state[param]['do_wd'] = False

In [None]:
class SAM(Callback):
    """Sharpness-Aware Minimization implemented as a fastai callback.

    Just before each optimizer step the callback:
      1. moves every parameter that has a gradient by
         `e_w = grad * rho / (||grad||_2 + eps)` (`first_step`),
      2. recomputes the predictions and the loss at the perturbed weights and
         backpropagates that loss,
      3. moves the parameters back to where they were (`second_step`).
    The optimizer step that follows therefore sees the gradients produced by
    the second backward pass.

    Parameters:
        zero_grad: if True, the optimizer's gradients are zeroed after the
            perturbation so that only the second backward pass contributes.
            If False the first-pass gradients are left in place
            (NOTE(review): the second backward then presumably accumulates
            onto them — confirm this is intended before using it).
        rho: size of the perturbation; must be non-negative.
        eps: small constant added to the gradient norm to avoid dividing by zero.
        **kwargs: accepted for call compatibility; not referenced by this class.

    Raises:
        AssertionError: if `rho` is negative.
    """
    # NOTE: defining this class rebinds the name `SAM`, shadowing the `SAM`
    # that the notebook's import cell pulls in from the `sam` module.
    def __init__(self, zero_grad=True, rho=0.05, eps=1e-12, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        # Per-parameter scratch space; holds the perturbation applied in
        # `first_step` until `second_step` undoes it.
        self.state = defaultdict(dict)
        store_attr()

    def params(self):
        "Optimizer parameter entries that currently have a gradient."
        return self.learn.opt.all_params(with_grad=True)

    def _grad_norm(self):
        "Global L2 norm over the gradients of all parameters returned by `params`."
        return torch.norm(torch.stack([p.grad.norm(p=2) for p,*_ in self.params()]), p=2)

    @torch.no_grad()
    def first_step(self):
        "Perturb the weights along the gradient direction and remember the perturbation."
        scale = self.rho / (self._grad_norm() + self.eps)
        for p,*_ in self.params():
            self.state[p]["e_w"] = e_w = p.grad * scale
            p.add_(e_w)  # climb to the local maximum "w + e(w)"
        if self.zero_grad: self.learn.opt.zero_grad()

    @torch.no_grad()
    def second_step(self):
        "Undo the perturbation applied by `first_step` and release the stored tensors."
        # Restore from the recorded perturbations rather than from the current
        # set of parameters-with-gradients: the two sets need not coincide after
        # the second backward pass, and every perturbed weight must be restored
        # exactly once.
        for p, entry in self.state.items():
            if "e_w" in entry: p.sub_(entry["e_w"])
        # Drop the perturbation tensors so they do not occupy memory (one extra
        # gradient-sized copy) until the next optimizer step.
        self.state.clear()

    def before_step(self, **kwargs):
        "Run the SAM perturb / re-evaluate / restore sequence ahead of the optimizer step."
        self.first_step()
        # Second forward pass at the perturbed weights; the 'after_pred' event is
        # fired again so other callbacks can process the new predictions.
        self.learn.pred = self.model(*self.xb); self.learn('after_pred')
        self.loss_func(self.learn.pred, *self.yb).backward()
        self.second_step()

In [None]:
# Learner combining the dataloaders and the transformer with a label-smoothed
# cross-entropy loss. Callbacks: the weight-decay exemption, MixUp (0.4) and
# the SAM callback defined in this notebook.
learn = Learner(
    dls,
    model=model,
    loss_func=LabelSmoothingCrossEntropy(),
    metrics=accuracy,
    cbs=[TstCallback, MixUp(0.4), SAM],
)

In [None]:
# First training run: fit_flat_cos called with 200 and 1e-4 (epochs and
# learning rate in fastai's signature). Metrics are appended to a CSV log and
# the weights with the best validation accuracy are saved as 'VIT100_4_6'.
learn.fit_flat_cos(
    200,
    1e-4,
    cbs=[
        CSVLogger(fname='VIT100_4_6.csv', append=True),
        SaveModelCallback(monitor='accuracy', fname='VIT100_4_6'),
    ],
)

epoch,train_loss,valid_loss,accuracy,time
0,0.724009,0.908566,0.485981,02:51
1,0.720077,0.698872,0.556075,02:51
2,0.72656,0.716148,0.565421,02:50
3,0.709981,0.698515,0.434579,02:51
4,0.700636,0.682061,0.565421,02:50
5,0.70274,0.711685,0.434579,02:49
6,0.696305,0.705318,0.565421,02:49
7,0.696082,0.684875,0.574766,02:50
8,0.696421,0.70102,0.560748,02:50
9,0.684488,0.713593,0.560748,02:51


Better model found at epoch 0 with accuracy value: 0.4859813153743744.
Better model found at epoch 1 with accuracy value: 0.5560747385025024.
Better model found at epoch 2 with accuracy value: 0.5654205679893494.
Better model found at epoch 7 with accuracy value: 0.5747663378715515.
Better model found at epoch 11 with accuracy value: 0.5794392228126526.
Better model found at epoch 12 with accuracy value: 0.6028037667274475.
Better model found at epoch 15 with accuracy value: 0.6261682510375977.
Better model found at epoch 18 with accuracy value: 0.644859790802002.
Better model found at epoch 19 with accuracy value: 0.6542056202888489.
Better model found at epoch 20 with accuracy value: 0.7056074738502502.
Better model found at epoch 23 with accuracy value: 0.7149532437324524.
Better model found at epoch 31 with accuracy value: 0.7383177280426025.
Better model found at epoch 32 with accuracy value: 0.7429906725883484.
Better model found at epoch 34 with accuracy value: 0.757009327411651

In [None]:
# Second training run on the same `learn` object with the same settings; the
# log is appended to the same CSV file and the same checkpoint name is reused.
learn.fit_flat_cos(
    200,
    1e-4,
    cbs=[
        CSVLogger(fname='VIT100_4_6.csv', append=True),
        SaveModelCallback(monitor='accuracy', fname='VIT100_4_6'),
    ],
)

epoch,train_loss,valid_loss,accuracy,time
0,0.314215,0.472042,0.82243,02:51
1,0.320483,0.4866,0.831776,02:51


Better model found at epoch 0 with accuracy value: 0.822429895401001.
Better model found at epoch 1 with accuracy value: 0.8317757248878479.
