In [6]:
# This file is for CSRNet_CBAM; make sure to use the name train_new.py in the new cell.
# If there is an import problem, run the last code cell.
# Also make sure that in train_new.py you change `csrnet.` to `csrnet_cbam`.


# After training finishes, rename checkpoint.pth and model_best.pth so they are not overwritten.

In [7]:
import torch.nn as nn
import torch
from torchvision import models
from utils import save_net,load_net


In [8]:
# NOTE: here we try training on Part A with CBAM.

In [9]:
# CSRNet with CBAM (Convolutional Block Attention Module) attention.

class ChannelAttention(nn.Module):
    """Channel-attention gate of CBAM.

    The input is squeezed to one value per channel twice, once with global
    average pooling and once with global max pooling. Both descriptors go
    through the same two-layer bottleneck of 1x1 convolutions, the results are
    summed, and a sigmoid turns the sum into one weight per channel.

    Args:
        in_planes: number of channels of the input feature map.
        ratio: reduction factor of the bottleneck. The hidden width is
            ``in_planes // ratio``, clamped to at least 1 so that an input with
            fewer than ``ratio`` channels does not get a zero-channel
            bottleneck.

    ``forward`` returns, for a 4-D input with ``in_planes`` channels, a tensor
    with spatial size 1x1 and ``in_planes`` channels holding sigmoid outputs.
    This module does not apply the weights itself; the caller multiplies them
    back into the input.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        # Global pooling compresses every feature map to a single value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Never let the bottleneck collapse to zero channels.
        hidden_planes = max(1, in_planes // ratio)

        # Shared MLP, implemented as two bias-free 1x1 convolutions.
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        )

        self.sigmoid = nn.Sigmoid()  # squashes the summed descriptors to per-channel weights

    def forward(self, x):
        """Return the per-channel attention weights for ``x``."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)
    


class SpatialAttention(nn.Module):
    """Spatial-attention gate of CBAM.

    The input is reduced along the channel axis (dim 1) with a mean and a max,
    the two single-channel maps are concatenated, and one bias-free
    convolution followed by a sigmoid produces a single-channel weight map.

    Args:
        kernel_size: side length of the convolution kernel (7 by default);
            the padding is ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # Two input channels: the channel-wise mean map and the channel-wise max map.
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the single-channel spatial attention map for ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv(pooled))

class CBAM(nn.Module):
    """CBAM block: channel attention followed by spatial attention.

    Args:
        in_planes: channel count of the input, passed to ``ChannelAttention``.
        ratio: bottleneck reduction factor, passed to ``ChannelAttention``.
        kernel_size: convolution kernel size, passed to ``SpatialAttention``.
    """

    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Scale ``x`` by its channel weights, then by the spatial weights of the result."""
        channel_refined = x * self.ca(x)
        # The spatial gate is computed on the already channel-refined features.
        return channel_refined * self.sa(channel_refined)

    
    # Builds a sequential stack of convolution and pooling layers.
def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build an ``nn.Sequential`` from a layer configuration list.

    Args:
        cfg: iterable of layer specs. ``'M'`` adds a 2x2 max-pool with stride 2;
            any other value is used as the output channel count of a 3x3
            convolution followed by an in-place ReLU.
        in_channels: channel count fed to the first convolution.
        batch_norm: if true, insert ``nn.BatchNorm2d`` between each
            convolution and its ReLU.
        dilation: if true, every convolution uses dilation 2 with padding 2;
            otherwise dilation 1 with padding 1.

    Returns:
        ``nn.Sequential`` holding the layers in ``cfg`` order.
    """
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        # Padding equals the dilation rate, so a 3x3 conv keeps the spatial size.
        modules.append(
            nn.Conv2d(channels, spec, kernel_size=3, padding=rate, dilation=rate)
        )
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)

class CSRNet_CBAM(nn.Module):
    """CSRNet-style network with a CBAM block between frontend and backend.

    Layout:
        * ``frontend``: 3x3 conv stack with three 2x2 max-pools (the layout in
          ``frontend_feat``), ending in 512 channels.
        * ``cbam``: channel + spatial attention over the 512-channel features.
        * ``backend``: dilated (rate 2) 3x3 conv stack from 512 down to 64 channels.
        * ``output_layer``: 1x1 conv producing a single-channel map.

    Because of the three max-pools, the output is spatially 1/8 of the input
    (e.g. a 255x255 input gives a 31x31 map).

    Args:
        load_weights: when false (default), every ``nn.Conv2d`` is initialised
            with N(0, 0.01) weights and zero bias, then the frontend is
            overwritten with the matching ImageNet-pretrained VGG-16 feature
            weights (this downloads the weights if they are not cached). When
            true, both steps are skipped and the layers keep PyTorch's default
            initialisation; presumably the caller then loads a checkpoint
            (confirm against the training script).
    """

    def __init__(self, load_weights=False):
        super(CSRNet_CBAM, self).__init__()
        # Counter initialised to 0; it is not updated anywhere in this class
        # (presumably maintained by the training script; confirm).
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]

        self.frontend = make_layers(self.frontend_feat)
        self.cbam = CBAM(512)  # attention applied to the frontend output
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)

        if not load_weights:
            mod = self._pretrained_vgg16()
            self._initialize_weights()
            vgg_state = mod.features.state_dict()
            frontend_state = self.frontend.state_dict()
            # Copy only tensors whose key and shape both match. The frontend
            # follows the same layer layout as the start of VGG-16's
            # ``features``, so the keys line up.
            matched_weights = {
                k: v for k, v in vgg_state.items()
                if k in frontend_state and v.size() == frontend_state[k].size()
            }
            frontend_state.update(matched_weights)
            self.frontend.load_state_dict(frontend_state)

    @staticmethod
    def _pretrained_vgg16():
        """Return an ImageNet-pretrained VGG-16 without using the deprecated flag.

        Newer torchvision releases expose ``VGG16_Weights`` and warn on
        ``pretrained=True``; older releases only understand ``pretrained``.
        Both paths load the same IMAGENET1K_V1 weights.
        """
        if hasattr(models, "VGG16_Weights"):
            return models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        return models.vgg16(pretrained=True)

    def forward(self, x):
        """Run frontend -> CBAM -> backend -> 1x1 output conv on ``x``."""
        x = self.frontend(x)
        x = self.cbam(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        """Set every ``nn.Conv2d`` in the model to N(0, 0.01) weights and zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

In [10]:
# Smoke test: build the network and push one random 3-channel 255x255 image
# through it to inspect the output shape.
model = CSRNet_CBAM()
x = torch.rand(1, 3, 255, 255)
model(x).shape



torch.Size([1, 1, 31, 31])

In [9]:
# Run train.py with the A_train.json / A_val.json lists (logged lr for this run: 1e-4).
# NOTE(review): the two trailing `0 0` arguments are not explained in this notebook;
# presumably GPU id and task id; confirm against train.py.
!python train.py A_train.json A_val.json 0 0

epoch 0, processed 0 samples, lr 0.0001000000
Epoch: [0][0/1540]	Time 0.763 (0.763)	Data 0.022 (0.022)	Loss 0.7575 (0.7575)	
Epoch: [0][100/1540]	Time 0.029 (0.032)	Data 0.011 (0.009)	Loss 0.5297 (0.6536)	
Epoch: [0][200/1540]	Time 0.031 (0.029)	Data 0.013 (0.008)	Loss 0.4086 (0.6803)	
Epoch: [0][300/1540]	Time 0.029 (0.028)	Data 0.011 (0.008)	Loss 0.6933 (0.6777)	
Epoch: [0][400/1540]	Time 0.031 (0.028)	Data 0.013 (0.008)	Loss 0.6385 (0.6776)	
Epoch: [0][500/1540]	Time 0.028 (0.028)	Data 0.010 (0.008)	Loss 0.6416 (0.6732)	
Epoch: [0][600/1540]	Time 0.023 (0.027)	Data 0.008 (0.008)	Loss 0.4777 (0.6786)	
Epoch: [0][700/1540]	Time 0.020 (0.027)	Data 0.005 (0.008)	Loss 0.7465 (0.6710)	
Epoch: [0][800/1540]	Time 0.027 (0.027)	Data 0.008 (0.008)	Loss 0.7143 (0.6688)	
Epoch: [0][900/1540]	Time 0.022 (0.027)	Data 0.008 (0.008)	Loss 1.0481 (0.6681)	
Epoch: [0][1000/1540]	Time 0.028 (0.027)	Data 0.010 (0.008)	Loss 0.6448 (0.6677)	
Epoch: [0][1100/1540]	Time 0.024 (0.027)	Data 0.005 (0.008)	Loss

In [1]:
# Hard-coded validation metric histories, one list per metric.
# NOTE(review): presumably one entry per epoch, copied from a training log; confirm the source run.
mae_list= [46.300458854006735, 28.633038461822824, 29.124831179982607, 28.264239616000776, 30.35370098192667, 28.72313591868607, 24.615138506151965, 28.22432853266136, 28.41189213388974, 30.42600147994523, 26.061958735751123, 25.97636707541869, 27.20912330666768, 26.37120925274092, 25.520352481566754, 26.023320748633946, 24.481220953243295, 26.531166587908242, 29.420082996801003, 27.109580973988955, 27.164920826548155, 26.77189885955496, 25.986446783714687, 24.640596153809852, 25.61320318634977, 25.270464730016965, 27.071955965966293, 27.305191531623763, 25.020767172587288, 24.811029935620496, 24.002286901179048, 25.405529769425538, 25.68986033410141, 25.173276114709598, 25.755648249203396, 25.17408038660423, 26.3671206641443, 26.722872684911355, 24.033000208667872, 26.27278864752386, 26.32103514917118, 26.36484422388765, 25.917635239276688, 25.597050027748974, 26.621093740168305, 26.34609075919869, 30.243239353612527, 25.641882571977437, 26.79238508165497, 28.25446494584231]
rmse_list= [73.22245915584513, 51.01574573003138, 46.94024717374272, 50.02872388903026, 43.567703221053875, 50.234042291285775, 42.63019257108826, 50.48291568617541, 41.44955933492963, 55.12899767744609, 40.01687839875432, 43.720876174308884, 47.53450725618991, 38.90465258339431, 41.68452669900206, 38.813363173955345, 40.775691479617144, 41.135955842595855, 41.19709581854875, 38.534922881113275, 38.5268476380534, 38.661911592718546, 40.881966607227405, 38.85534734918878, 38.19450859110378, 37.78666325927992, 38.63187151944638, 37.44881551016399, 40.18735633845925, 38.4477908191656, 38.54487728950813, 37.55400685524825, 41.257644993221035, 38.69684178932022, 40.098730568839684, 39.31899467886701, 39.516217825184064, 40.12757416823165, 38.23392608015764, 39.98110571443449, 41.320886050416654, 38.54957199638684, 39.53415151306872, 41.00605082618708, 39.406956096813126, 40.66304154382625, 40.26801653991187, 37.77346060124608, 39.45777965015736, 38.572790279119324]
ssim_list= [0.567025819356486, 0.49691231308754574, 0.4647512297174826, 0.45517333011267724, 0.43700255033051233, 0.4184222675610295, 0.38981732607041436, 0.39554766233357574, 0.406358961207965, 0.40782039601010145, 0.4148335572447359, 0.41125894456948203, 0.42366768934370314, 0.42847230339173187, 0.41473574794291224, 0.4247520348614024, 0.40448172544081185, 0.42508455527197453, 0.45099801462642924, 0.4379604858989568, 0.43844754494649846, 0.45938711498201507, 0.44325226506928805, 0.4247773840024914, 0.43803719968832644, 0.43057986256695285, 0.4427977470853894, 0.4615517755605511, 0.41711608339677153, 0.43966611605329614, 0.41718908057538506, 0.4437148478842273, 0.4350352239286162, 0.43840442490331905, 0.434621858074493, 0.44053953267710727, 0.4571680053300464, 0.4540604424384451, 0.42666447105020594, 0.44640339227374065, 0.4443436589130421, 0.4420680164829972, 0.46093932324156317, 0.42505525306020814, 0.45784522278099943, 0.44406096468266754, 0.4583111243303289, 0.43651665055874694, 0.4530828404672367, 0.46760562876450645]
psnr_list= [20.474549878503858, 21.428932730684576, 20.79651668391277, 21.777693215104723, 20.708226828231027, 22.10522717544713, 21.646939464451112, 22.18431630577009, 19.87691333613445, 21.761755594273204, 20.598910528359955, 21.146107771961958, 21.078859599595216, 20.100792550549066, 20.64555714302456, 19.924562955640027, 20.512416323435676, 20.162181598624002, 18.913230148787353, 19.25675470312846, 19.47790723486045, 19.760959782551243, 20.38379797001475, 20.2715214699814, 20.073505608076903, 20.36186972844232, 19.20508482529945, 19.496959263516455, 20.634175772519455, 20.09793925039547, 20.587783956036127, 20.04398570601473, 20.0086894625241, 20.52893690718818, 20.009165085468094, 20.443279821848133, 19.73386303911504, 19.956672599635173, 20.455753803253174, 20.020340595048726, 20.56121643302367, 19.947740992319954, 20.36627188908685, 20.528801151157655, 20.107770688754997, 20.092816947661724, 18.846075520073015, 20.05842177400884, 19.873007233609858, 19.30557297185524]
# Report the best value of each metric: lowest MAE/RMSE, highest SSIM/PSNR.
# Each best is taken independently over its own list, so the four numbers do
# not necessarily come from the same entry.
summary_lines = "\n".join([
    "\n==== Best Metrics Summary ====",
    f"Best MAE:  {min(mae_list):.3f}",
    f"Best RMSE: {min(rmse_list):.3f}",
    f"Best SSIM: {max(ssim_list):.3f}",
    f"Best PSNR: {max(psnr_list):.3f}",
])
print(summary_lines)



==== Best Metrics Summary ====
Best MAE:  24.002
Best RMSE: 37.449
Best SSIM: 0.567
Best PSNR: 22.184


In [20]:
# Run train_new_cbam.py on the A_train.json / A_val.json lists.
# In this run the log shows lr 1e-5 and running-average losses in the thousands.
!python train_new_cbam.py A_train.json A_val.json 0 0

epoch 0, processed 0 samples, lr 0.0000100000
Epoch: [0][0/400]	Time 0.844 (0.844)	Data 0.024 (0.024)	Loss 19669.6270 (19669.6270)	
Epoch: [0][100/400]	Time 0.118 (0.120)	Data 0.072 (0.063)	Loss 86.9001 (24396.3087)	
Epoch: [0][200/400]	Time 0.065 (0.114)	Data 0.024 (0.063)	Loss 818.6946 (19387.0159)	
Epoch: [0][300/400]	Time 0.053 (0.112)	Data 0.031 (0.064)	Loss 415.1665 (21283.0213)	
 * MAE: 223.354, RMSE: 359.673, SSIM: 0.387, PSNR: 20.249
Saved best model at epoch 0 with MAE: 223.354, SSIM: 0.387, PSNR: 20.249
 * Best MAE so far: 223.354
epoch 1, processed 400 samples, lr 0.0000100000
Epoch: [1][0/400]	Time 0.049 (0.049)	Data 0.015 (0.015)	Loss 3208.8196 (3208.8196)	
Epoch: [1][100/400]	Time 0.119 (0.104)	Data 0.074 (0.068)	Loss 14836.8398 (8907.7379)	
Epoch: [1][200/400]	Time 0.119 (0.103)	Data 0.088 (0.068)	Loss 47.3031 (6886.0748)	
Epoch: [1][300/400]	Time 0.085 (0.105)	Data 0.043 (0.069)	Loss 16.8519 (9196.9944)	
 * MAE: 161.351, RMSE: 248.738, SSIM: 0.368, PSNR: 20.549
Saved b

In [22]:
# Same command as the previous train_new_cbam.py run; the log again shows lr 1e-5,
# but the losses are now in the tens.
# NOTE(review): the script was presumably edited between the two runs; record what changed.
!python train_new_cbam.py A_train.json A_val.json 0 0

epoch 0, processed 0 samples, lr 0.0000100000
Epoch: [0][0/400]	Time 0.834 (0.834)	Data 0.025 (0.025)	Loss 119.6613 (119.6613)	
Epoch: [0][100/400]	Time 0.121 (0.119)	Data 0.040 (0.052)	Loss 15.1689 (35.1250)	
Epoch: [0][200/400]	Time 0.108 (0.116)	Data 0.064 (0.056)	Loss 25.1037 (32.7417)	
Epoch: [0][300/400]	Time 0.133 (0.115)	Data 0.088 (0.061)	Loss 109.4440 (31.2043)	
 * MAE: 165.572, RMSE: 265.223, SSIM: 0.350, PSNR: 20.416
Saved best model at epoch 0 with MAE: 165.572, SSIM: 0.350, PSNR: 20.416
 * Best MAE so far: 165.572
epoch 1, processed 400 samples, lr 0.0000100000
Epoch: [1][0/400]	Time 0.039 (0.039)	Data 0.011 (0.011)	Loss 22.1194 (22.1194)	
Epoch: [1][100/400]	Time 0.077 (0.103)	Data 0.036 (0.068)	Loss 32.1309 (22.5128)	
Epoch: [1][200/400]	Time 0.106 (0.105)	Data 0.077 (0.069)	Loss 5.8503 (22.6947)	
Epoch: [1][300/400]	Time 0.104 (0.105)	Data 0.052 (0.068)	Loss 8.3459 (21.1546)	
 * MAE: 146.909, RMSE: 231.735, SSIM: 0.385, PSNR: 20.675
Saved best model at epoch 1 with MAE

In [24]:
# Same command again; this run logs lr 1e-3, and the validation MAE stays at
# about 330 for both logged epochs.
!python train_new_cbam.py A_train.json A_val.json 0 0

epoch 0, processed 0 samples, lr 0.0010000000
Epoch: [0][0/400]	Time 0.756 (0.756)	Data 0.015 (0.015)	Loss 35.3127 (35.3127)	
Epoch: [0][100/400]	Time 0.115 (0.120)	Data 0.082 (0.061)	Loss 7.2819 (16388.1805)	
Epoch: [0][200/400]	Time 0.086 (0.116)	Data 0.040 (0.059)	Loss 34.3531 (8251.9153)	
Epoch: [0][300/400]	Time 0.110 (0.112)	Data 0.073 (0.060)	Loss 27.3122 (5520.6089)	
 * MAE: 330.275, RMSE: 500.495, SSIM: 0.217, PSNR: 18.744
Saved best model at epoch 0 with MAE: 330.275, SSIM: 0.217, PSNR: 18.744
 * Best MAE so far: 330.275
epoch 1, processed 400 samples, lr 0.0010000000
Epoch: [1][0/400]	Time 0.073 (0.073)	Data 0.020 (0.020)	Loss 31.5580 (31.5580)	
Epoch: [1][100/400]	Time 0.097 (0.106)	Data 0.070 (0.064)	Loss 1.0714 (31.1138)	
Epoch: [1][200/400]	Time 0.130 (0.108)	Data 0.080 (0.066)	Loss 37.0157 (29.7779)	
Epoch: [1][300/400]	Time 0.098 (0.107)	Data 0.052 (0.066)	Loss 19.3900 (31.2981)	
 * MAE: 330.384, RMSE: 500.462, SSIM: 0.217, PSNR: 18.743
 * Best MAE so far: 330.275
epoc

In [25]:
# Run train_nocountloss.py on the same lists (logged lr 1e-3).
# NOTE(review): presumably the training variant without the count-loss term; confirm against the script.
!python train_nocountloss.py A_train.json A_val.json 0 0

epoch 0, processed 0 samples, lr 0.0010000000
Epoch: [0][0/400]	Time 0.819 (0.819)	Data 0.026 (0.026)	Loss 0.3204 (0.3204)	
Epoch: [0][100/400]	Time 0.075 (0.114)	Data 0.033 (0.059)	Loss 0.3753 (1.2957)	
Epoch: [0][200/400]	Time 0.096 (0.109)	Data 0.063 (0.061)	Loss 0.4856 (0.9645)	
Epoch: [0][300/400]	Time 0.119 (0.108)	Data 0.077 (0.063)	Loss 0.5826 (0.8422)	
 * MAE: 466.614, RMSE: 671.145, SSIM: 0.420, PSNR: 18.748
Saved best model at epoch 0 with MAE: 466.614, SSIM: 0.420, PSNR: 18.748
 * Best MAE so far: 466.614
epoch 1, processed 400 samples, lr 0.0010000000
Epoch: [1][0/400]	Time 0.068 (0.068)	Data 0.029 (0.029)	Loss 0.6391 (0.6391)	
Epoch: [1][100/400]	Time 0.109 (0.103)	Data 0.080 (0.069)	Loss 0.4532 (0.6248)	
Epoch: [1][200/400]	Time 0.080 (0.102)	Data 0.036 (0.068)	Loss 0.2856 (0.6063)	
Epoch: [1][300/400]	Time 0.094 (0.104)	Data 0.077 (0.070)	Loss 0.8747 (0.5974)	
 * MAE: 434.297, RMSE: 647.870, SSIM: 0.432, PSNR: 18.976
Saved best model at epoch 1 with MAE: 434.297, SSIM: 

In [23]:
# Force a re-import of the local `model` module so that edits to it are picked
# up without restarting the kernel, then list the names it exposes.
import importlib
import model

model = importlib.reload(model)
print(dir(model))

['CBAM', 'CSRNet_CBAM', 'ChannelAttention', 'SpatialAttention', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'make_layers', 'models', 'nn', 'torch']
