In [None]:
%matplotlib inline
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from customlib import d2l
from customlib import u2net
from customlib import unet
from customlib import loss_function as ls
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.metrics import jaccard_score
import numpy as np
from thop import profile
import zipfile
import os

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name() raises when CUDA is unavailable, so only query it on a GPU;
# otherwise the CPU fallback above would still crash here.
device_name = torch.cuda.get_device_name() if device.type == "cuda" else "CPU"
print("current device is", device, device_name)
# Directory holding checkpoints (.pth) and the training log (.txt).
save_path = '../model/'
# Run identifier; the training cell writes its log to save_path/<work_name>.txt.
# NOTE(review): presumably also names the saved checkpoint (cf. 'best_<run>.pth' loaded
# elsewhere in this notebook) — confirm in d2l.train_ch13_txt.
work_name = "240506_u2net+samenc+aug"

In [None]:
# Disabled sanity check for the loss functions in customlib.loss_function: evaluates each
# loss either on one test batch run through `net`, or on random logits of shape
# (3, 2, 384, 384) with random 0/1 labels of shape (3, 384, 384).
# Uncomment to run; it needs CUDA (.cuda() calls) plus `net` and `test_iter`, which are
# defined in other cells of this notebook.
# batch_in, labels = list(iter(test_iter))[0]
# batch_out = net(batch_in.cuda())
# batch_out = torch.randn(3,2, 384, 384)
# labels = (torch.randn (3,384, 384) > 0).long()
# bce_dice_loss = ls.BCEDiceLoss()
# print("iouloss     =  " + str(ls.jaccard_loss(batch_out.cuda(), labels.cuda())))
# print("celoss      =  " + str(ls.ce_loss(batch_out.cuda(), labels.cuda()))) 
# print("diceloss    =  " + str(ls.dice_loss(batch_out.cuda(), labels.cuda()))) 
# print("focalloss   =  " + str(ls.focal_loss(batch_out.cuda(), labels.cuda())))
# print("tverskyloss =  " + str(ls.tversky_loss(batch_out.cuda(), labels.cuda())))
# print("ce_iou      =  " + str(ls.ceiou_loss(batch_out.cuda(), labels.cuda()))) 
# print("bce_dice      =  " + str(bce_dice_loss(batch_out.cuda(), labels.cuda())))

In [None]:
# Build the U2Net segmentation model with two output channels and switch it to eval mode
# (Module.eval() returns the module itself, so the call can be chained).
# Alternative model tried earlier: fpn.FPNNet(num_classes=2)
net = u2net.u2net_full(out_ch=2).eval()
net

In [None]:
# Restore the whole model (a pickled nn.Module, not just a state dict) from the best
# checkpoint of run 240505 and place it on the first available device.
# NOTE(review): this replaces any `net` built earlier; running the SAM-transfer cell after
# this one overwrites the restored net.tfenc weights — mind the execution order.
devices = d2l.try_all_gpus()
# torch.load unpickles arbitrary objects: only load checkpoints you produced yourself.
# weights_only=False is required to unpickle a full module (torch>=2.6 defaults to True);
# map_location avoids depending on the GPU the checkpoint was saved from.
net = torch.load(os.path.join(save_path, 'best_240505_u2net+samenc+aug.pth'),
                 map_location=devices[0], weights_only=False)
net.eval()
net.to(devices[0])

In [None]:
# Transfer the pretrained SAM ViT-B image-encoder weights into net.tfenc.
# A checkpoint entry is copied only if its name starts with 'image_encoder.', the remaining
# name exists in net.tfenc's state dict, and the tensor shapes agree.
# NOTE(review): a later cell that copies a full checkpoint into `net` overwrites these
# weights wherever that checkpoint has matching 'tfenc.*' entries — mind the execution order.
sam_prefix = 'image_encoder.'
# map_location='cpu' avoids depending on the device the checkpoint was saved from
# (load_state_dict copies each tensor onto tfenc's own device). The file is only used as a
# name -> tensor mapping below, so weights_only=True (no arbitrary unpickling) suffices.
sam_weights = torch.load(os.path.join(save_path, 'sam_vit_b.pth'),
                         map_location='cpu', weights_only=True)
from functools import partial  # NOTE(review): unused in this cell; kept in case other cells rely on it

# Target names and shapes of the transformer encoder's parameters/buffers.
state_dict_new = net.tfenc.state_dict()

# Weights that pass the name/shape check, keyed by their tfenc name.
state_dict_to_load = {}
success_count = 0
fail_count = 0

for name, param in sam_weights.items():
    if not name.startswith(sam_prefix):
        continue  # non-encoder SAM weights are not used
    new_name = name[len(sam_prefix):]
    if new_name in state_dict_new and param.shape == state_dict_new[new_name].shape:
        state_dict_to_load[new_name] = param
        success_count += 1
    else:
        print(f'Failed to load weight: {name}, shape: {param.shape}')
        fail_count += 1

# strict=False: tfenc entries without a SAM counterpart keep their current values.
load_result = net.tfenc.load_state_dict(state_dict_to_load, strict=False)

print(f'Successfully loaded weights: {success_count}')
print(f'Failed to load weights: {fail_count}')
# strict=False would otherwise hide encoder entries that received nothing from SAM
# (e.g. if every key name mismatched), so report them explicitly.
print(f'Encoder entries not initialised from SAM: {len(load_result.missing_keys)}')


In [None]:
# Copy weights from the previous run's checkpoint into `net`, taking only entries whose name
# and shape both match. Anything that does not match keeps its current value (strict=False),
# so the checkpoint can come from a slightly different architecture.
# NOTE(review): this overwrites net.tfenc wherever the checkpoint has matching 'tfenc.*'
# entries, including weights just transferred from SAM — mind the execution order.
# torch.load unpickles arbitrary objects: only load checkpoints you produced yourself.
# weights_only=False is needed for a pickled nn.Module (torch>=2.6 defaults to True);
# map_location='cpu' avoids needing the saving GPU — load_state_dict copies each tensor
# onto the device of the matching entry in `net`.
model_old = torch.load(os.path.join(save_path, '{}.pth'.format("best_240505_u2net+samenc+aug")),
                       map_location='cpu', weights_only=False)

# Accept either a whole pickled module or a bare state dict.
state_dict_old = model_old.state_dict() if isinstance(model_old, nn.Module) else model_old

# A model saved while wrapped in nn.DataParallel prefixes every key with 'module.'.
# Strip it only at the start of a key: str.replace would also mangle names that merely
# contain 'module.' (e.g. 'submodule.weight' -> 'subweight').
state_dict_old = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict_old.items()
}

# Target names and shapes of the model being initialised.
state_dict_new = net.state_dict()

# Weights that pass the name/shape check.
state_dict_to_load = {}
success_count = 0
fail_count = 0

for name, param in state_dict_old.items():
    if name in state_dict_new and param.shape == state_dict_new[name].shape:
        state_dict_to_load[name] = param
        success_count += 1
    else:
        # print(f'Failed to load weight: {name}, shape: {param.shape}')
        fail_count += 1

net.load_state_dict(state_dict_to_load, strict=False)

print(f'Successfully loaded weights: {success_count}')
print(f'Failed to load weights: {fail_count}')

In [None]:
# Model complexity for one random 1x3x384x384 input (same spatial size as crop_size in the
# data-loading cell). Build the input on the model's own device: `net` may already sit on a
# GPU, and a CPU tensor fed to CUDA weights raises a device-mismatch error.
X = torch.rand(size=(1, 3, 384, 384), device=next(net.parameters()).device)
# thop.profile returns (multiply-accumulate count, parameter count). The first value is MACs,
# not FLOPs (FLOPs ~= 2 x MACs); the variable keeps its old name for compatibility.
flops, params = profile(net, inputs=(X,))
with torch.no_grad():  # shape check only, no autograd graph needed
    print(net(X).shape)
print("参数量：", params/1e6, "M")  # "parameter count"
print("MACs：", flops/1e9, "G")

In [4]:
# Training / test loaders from the project's d2l.load_data_voc (is_new=True variant).
batch_size = 4
crop_size = (384, 384)  # (height, width) passed to the loader
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size, num_workers=4, is_new=True)
# Previous configuration:
# train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size, num_workers=0, is_new=False)

# Shape of the first input sample in each split (train first, then test).
for loader in (train_iter, test_iter):
    print(loader.dataset[0][0].shape)

In [None]:
# Training hyper-parameters.
num_epochs = 200
lr = 0.0001
wd = 1e-4
devices = d2l.try_all_gpus()
savebest = 1  # forwarded to d2l.train_ch13_txt — presumably toggles best-checkpoint saving; confirm

# Optimisers tried previously:
#   torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
#   torch.optim.RMSprop(net.parameters(), lr=0.001, alpha=0.99, eps=1e-08, weight_decay=wd)
trainer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=wd)
# NOTE(review): T_max=50 is shorter than num_epochs=200. CosineAnnealingLR keeps following the
# cosine past T_max, so the LR will not stay at its minimum — confirm this is intended
# (depends on how often d2l.train_ch13_txt steps the scheduler).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, T_max=50)

log_path = os.path.join(save_path, f'{work_name}.txt')
d2l.train_ch13_txt(net, train_iter, test_iter, ls.focal_rce_loss, trainer, scheduler,
                   num_epochs, log_path, work_name, savebest, devices)