In [60]:
import sys
from PIL import Image
import glob
import time
from pathlib import Path
from typing import Iterable,Optional
import math
import torch
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import torch.nn as nn
import torchvision
import argparse
import os
import timm
from timm.utils import accuracy
from torch.utils.tensorboard import SummaryWriter
from util import misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_args_parser():
    parser = argparse.ArgumentParser("Mae pre-training",add_help=False)
    parser.add_argument('--batch_size',default=32,type=int,help='Batch size per GPU (effective batch size is batch_size*accum_iter* #gpus)')
    parser.add_argument('--epochs',default=400,type=int)
    parser.add_argument('--accum_iter',default=1,type=int)
    #Model parameters
    parser.add_argument('--input_size',default=128,type=int,help='images input size')
    #Optimizer parameters
    parser.add_argument('--weight_decay',type=float,default=0.0001)
    parser.add_argument('--lr',type=float,default=0.0001,metavar='LR')
    parser.add_argument('--root_path',default='./dataset_fruit_veg')
    parser.add_argument('--output_dir',default='./output_dir_pretrained',help='path to save,empty for no saving')
    parser.add_argument('--log_dir',default='./output_dir_pretrained',help='path to tensorboard log')
    
    parser.add_argument('--resume',default='',help='resume from checkpoint')
    parser.add_argument('--start_epoch',default=0,type=int,metavar='N')
    parser.add_argument('--num_workers',default=5,type=int)
    parser.add_argument('--pin_mem',action='store_true')
    parser.add_argument('--no_pin_mem',action='store_false',dest='pin_mem')
    parser.set_defaults(pin_mem=True)
    return parser
    
'''创建预处理的transform'''
def build_transform(is_train,args):
    if is_train:
        print("train transform")
        return torchvision.transforms.Compose([
            torchvision.transforms.Resize((args.input_size,args.input_size)),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomVerticalFlip(),
            torchvision.transforms.RandomPerspective(distortion_scale=0.6,p=1.0),
            torchvision.transforms.GaussianBlur(kernel_size=(5,9),sigma=(0.1,5)),
            torchvision.transforms.ToTensor()
        ])
    else:
        print("eval transform")
        return torchvision.transforms.Compose([
            torchvision.transforms.Resize((args.input_size,args.input_size)),
            torchvision.transforms.ToTensor()
        ])
'''创建数据集 返回dataset'''
def build_dataset(is_train,args):
    transform = build_transform(is_train,args)
    path = os.path.join(args.root_path,'train' if is_train else 'test')
    dataset = torchvision.datasets.ImageFolder(path,transform= transform)
    info = dataset.find_classes(path)
    #print(f"finding classes from {path}: {info[0]}")
    print(f"mapping classes from {path} to indexes:{info[1]}")
    return dataset

''' 验证函数
    输入：data_loader，model，device
    输出：test_stats{top1,top5,loss}
'''
@torch.no_grad()
def evaluate(data_loader,model,device):
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = 'Test:'
    model = model.to(device)
    model.eval()
    #下面这段话基本等价于 for (images, targets) in data_loader：
    for batch in metric_logger.log_every(data_loader,
                                         100, #打印间隔
                                         header): #标题
        images = batch[0]
        target = batch[-1]
        images = images.to(device,non_blocking=True)
        target = target.to(device,non_blocking=True)
        output = model(images)
        loss = criterion(output,target)
        output = torch.nn.functional.softmax(output,dim=1) 
        acc1,acc5 = accuracy(output,target,topk=(1,5)) #top1正确率和top5正确率
        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(),n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(),n=batch_size)
        
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1,top5=metric_logger.acc5,losses=metric_logger.loss))
    return {k:meter.global_avg for k,meter in metric_logger.meters.items()}

'''
    训练函数
'''
def train_one_epoch(model:torch.nn.Module,criterion:torch.nn.Module,
                    data_loader:Iterable,optimizer:torch.optim.Optimizer,
                    device:torch.device,epoch:int,loss_scaler,max_norm: float=0,
                    log_writer=None,args=None):
    model.train(True)
    print_freq = 2
    accum_iter = args.accum_iter
 
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    for data_iter_step,(samples,targets) in enumerate(data_loader):
        samples = samples.to(device,non_blocking=True)
        targets = targets.to(device,non_blocking=True)
        
        outputs = model(samples)
        warmup_lr = args.lr
        optimizer.param_groups[0]["lr"] = warmup_lr
        
        loss = criterion(outputs,targets)
        loss /= accum_iter
        
        loss_scaler(loss,optimizer,clip_grad=max_norm, 
                    parameters=model.parameters(),create_graph=False,
                    update_grad=(data_iter_step+1)%accum_iter == 0) #训练每accum_iter个batch才更新梯度
        loss_value = loss.item()
        if (data_iter_step+1)%accum_iter == 0:
            optimizer.zero_grad()
        if not math.isfinite(loss_value):
            print(f"loss is {loss_value}, stopping training")
            sys.exit(1)
        if log_writer is not None and (data_iter_step+1)%accum_iter == 0:
            epoch_1000x = int((data_iter_step/len(data_loader)+epoch)*1000)
            log_writer.add_scalar('loss',loss_value,epoch_1000x)
            log_writer.add_scalar('lr',warmup_lr,epoch_1000x)
            print(f"Epoch: {epoch}, Step: {data_iter_step}, Loss: {loss}, Lr: {warmup_lr}")
            

def main(args,mode='train',test_image_path=''):
    print(f"当前mode: {mode}")
    if mode =='train':
        #构建批次
        dataset_train = build_dataset(is_train=True,args=args)
        dataset_val = build_dataset(is_train=False,args=args)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        
        data_loader_train = torch.utils.data.DataLoader(
            dataset=dataset_train,sampler=sampler_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=True,
        )
        data_loader_val = torch.utils.data.DataLoader(
            dataset=dataset_val,sampler=sampler_val,
            #batch_size=1,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False,
        )
        
        #构建模型
        model = timm.create_model('resnet18',pretrained=True,num_classes=36,drop_rate=0.1,drop_path_rate=0.1)
        model = model.to(device)
        n_parameters = sum([p.numel() for p in model.parameters() if p.requires_grad])

        print(f"number of trainable parameters(M):{n_parameters/1.e6:.2f}") #f-string保留两位小数{xxx:.2f}
        criterion = torch.nn.CrossEntropyLoss()
           
        #weight_decay就是对损失函数做L2正则化，防止过拟合
        optimizer = torch.optim.AdamW(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)     
        #用tensorboard记录日志
        os.makedirs(args.log_dir,exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
        
        #lossScaler用来反传梯度用的
        loss_scaler = NativeScaler()
        
        #读入已有的模型 resume为空字符串 则不会读取，如果传入时pth文件，则会读取原来的模型
        #读取进来时args里epoch会+1
        misc.load_model(args=args,model_without_ddp=model,optimizer=optimizer,loss_scaler=loss_scaler)
        
        for epoch in range(args.start_epoch,args.epochs): #start_epoch开始训练
            print(f"Epoch {epoch}")
            print(f"length of data_loader_train is {len(data_loader_train)}") #几个batch
            
            if epoch % 1 == 0:
                print("Evaluating...")
                model.eval()
                test_stats = evaluate(data_loader_val,model,device)
                print(f"Accuracy on the {len(dataset_val)} test images {test_stats['acc1']:.2f}")
                if log_writer is not None:
                    
                    ''' 
                        add_scalar(tag, scalar_value, global_step=None, walltime=None) 
                        add_scalar:记录标量函数,参数:
                        tag：图的名称 scalar_value：记录的值 global_step：x轴
                    '''
                    log_writer.add_scalar('perf/test_acc1',test_stats['acc1'],epoch)
                    log_writer.add_scalar('perf/test_acc5',test_stats['acc5'],epoch)
                    log_writer.add_scalar('perf/test_loss',test_stats['loss'],epoch)
                model.train()
                
            print("Training...")
            train_stats = train_one_epoch(
                model,criterion,data_loader_train,
                optimizer,device,epoch,#epoch+1, #为什么要+1？
                loss_scaler,None,
                log_writer=log_writer,args=args
            )
            if args.output_dir:
                print("Saving checkpoint...")
                misc.save_model(args=args,model=model,model_without_ddp=model,optimizer=optimizer,
                               loss_scaler=loss_scaler,epoch=epoch)
        
    else: #infer模式
        #pretrained=False加载自己的模型
        model = timm.create_model('resnet18',pretrained=False,num_classes=36,drop_rate=0.1,drop_path_rate=0.1)
        class_dict = {'apple': 0, 'banana': 1, 'beetroot': 2, 'bell pepper': 3, 'cabbage': 4, 'capsicum': 5, 'carrot': 6, 'cauliflower': 7, 'chilli pepper': 8, 'corn': 9, 'cucumber': 10, 'eggplant': 11, 'garlic': 12, 'ginger': 13, 'grapes': 14, 'jalepeno': 15, 'kiwi': 16, 'lemon': 17, 'lettuce': 18, 'mango': 19, 'onion': 20, 'orange': 21, 'paprika': 22, 'pear': 23, 'peas': 24, 'pineapple': 25, 'pomegranate': 26, 'potato': 27, 'raddish': 28, 'soy beans': 29, 'spinach': 30, 'sweetcorn': 31, 'sweetpotato': 32, 'tomato': 33, 'turnip': 34, 'watermelon': 35}
        n_parameters = sum([p.numel() for p in model.parameters() if p.requires_grad])
        print(f"number of trainable parameters(M):{n_parameters/1.e6:.2f}") #f-string保留两位小数{xxx:.2f}
        optimizer = torch.optim.AdamW(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)     
        loss_scaler = NativeScaler()
        misc.load_model(args=args,model_without_ddp=model,optimizer=optimizer,loss_scaler=loss_scaler)
        model.eval()
        
        image = Image.open(test_image_path).convert('RGB')#test_image_path是main函数的参数
        image =image.resize((args.input_size,args.input_size),Image.ANTIALIAS)
        
        #ToTensor()是创建了一个类实例 ，类实例(image)是调用_call_方法
        image = torchvision.transforms.ToTensor()(image).unsqueeze(0)
        
        with torch.no_grad():
            output = model(image)
        output = torch.nn.functional.softmax(output,dim=-1)
        class_idx = torch.argmax(output,dim=-1)[0] # softmax之后是一个(bs,)的tensor
        score = torch.max(output,dim=-1).values[0] #max返回{values, indices}的字典
        
        print(f"test image path is {test_image_path}")
        print(f"score is {score.item()}, class id is {class_idx.item()}, class name is {list(class_dict.keys())[list(class_dict.values()).index(class_idx)]}")
        time.sleep(0.5)
if __name__=='__main__':
    args = get_args_parser()
    #jupyter和命令行中不一样，命令行可以通过 python train.py --batch_size 32 ..,传入参数
    #但是jupyter不行，需要自己给args赋值一个列表，模仿命令行读入的参数，然后让argparse解析
    #args = args.parse_args(args=['--batch_size','128','--epochs','3','--num_workers','2'])
    #resume参数加载旧模型时的参数
    args = args.parse_args(args=['--batch_size','64','--epochs','20','--num_workers','2','--resume','output_dir_pretrained/checkpoint-19.pth'])
    
    print(args)
    
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True,exist_ok=True)
    #mode = 'train' #infer or train
    mode = 'infer' #infer or train
    if mode =='train':
        main(args,mode=mode)
    else:
        images = glob.glob('dataset_fruit_veg/test/*/*.jpg') #测试图片的的路径
        #print('test images: ',images)
        for image in images:
            print('\n')
            main(args,mode = mode,test_image_path=image)
    #build_dataset(is_train=True,args=args)
    #main(args=args)
    

Namespace(batch_size=64, epochs=20, accum_iter=1, input_size=128, weight_decay=0.0001, lr=0.0001, root_path='./dataset_fruit_veg', output_dir='./output_dir_pretrained', log_dir='./output_dir_pretrained', resume='output_dir_pretrained/checkpoint-19.pth', start_epoch=0, num_workers=2, pin_mem=True)


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!


  image =image.resize((args.input_size,args.input_size),Image.ANTIALIAS)


test image path is dataset_fruit_veg/test\apple\Image_24.jpg
score is 0.29847773909568787, class id is 0, class name is apple


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\apple\Image_25.jpg
score is 0.508858323097229, class id is 26, class name is pomegranate


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\apple\Image_33.jpg
score is 0.48444199562072754, class id is 26, class name is pomegranate


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\apple\Image_42.jpg
score is 0.5927813649177551, class id is 26, class name is pomegranate


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint ou

test image path is dataset_fruit_veg/test\bell pepper\Image_33.jpg
score is 0.33190035820007324, class id is 5, class name is capsicum


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\bell pepper\Image_4.jpg
score is 0.3875906467437744, class id is 33, class name is tomato


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\bell pepper\Image_5.jpg
score is 0.37950196862220764, class id is 3, class name is bell pepper


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\bell pepper\Image_52.jpg
score is 0.3361292779445648, class id is 3, class name is bell pepper


当前mode: infer
number of trainable parameters(M):11.19
R

test image path is dataset_fruit_veg/test\carrot\Image_61.jpg
score is 0.6956587433815002, class id is 6, class name is carrot


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\carrot\Image_71.jpg
score is 0.20041324198246002, class id is 32, class name is sweetpotato


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\carrot\Image_8.jpg
score is 0.8273293972015381, class id is 6, class name is carrot


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\carrot\Image_83.jpg
score is 0.3404100239276886, class id is 31, class name is sweetcorn


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output

test image path is dataset_fruit_veg/test\corn\Image_43.jpg
score is 0.23707476258277893, class id is 29, class name is soy beans


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\corn\Image_52.jpg
score is 0.439523845911026, class id is 31, class name is sweetcorn


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\corn\Image_98.jpg
score is 0.36249858140945435, class id is 9, class name is corn


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\cucumber\Image_11.jpg
score is 0.32202887535095215, class id is 10, class name is cucumber


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_di

test image path is dataset_fruit_veg/test\garlic\Image_86.jpg
score is 0.2980020046234131, class id is 34, class name is turnip


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\garlic\Image_96.jpg
score is 0.8135331273078918, class id is 12, class name is garlic


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\ginger\Image_16.jpg
score is 0.752241313457489, class id is 13, class name is ginger


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\ginger\Image_19.jpg
score is 0.4653821289539337, class id is 7, class name is cauliflower


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_di



当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\kiwi\Image_31.jpg
score is 0.9359644055366516, class id is 16, class name is kiwi


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\kiwi\Image_42.jpg
score is 0.6852593421936035, class id is 16, class name is kiwi


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\kiwi\Image_44.jpg
score is 0.36552420258522034, class id is 24, class name is peas


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\kiwi\Image_69.jpg
score is 0.23128926753997803, c

test image path is dataset_fruit_veg/test\mango\Image_39.jpg
score is 0.24164451658725739, class id is 21, class name is orange


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\mango\Image_4.jpg
score is 0.13200688362121582, class id is 9, class name is corn


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\mango\Image_48.jpg
score is 0.21371448040008545, class id is 19, class name is mango


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\mango\Image_65.jpg
score is 0.3448849618434906, class id is 35, class name is watermelon


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pre



当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\paprika\Image_66.jpg
score is 0.3235718607902527, class id is 32, class name is sweetpotato


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\paprika\Image_81.jpg
score is 0.4292537569999695, class id is 35, class name is watermelon


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\paprika\Image_88.jpg
score is 0.12270262092351913, class id is 2, class name is beetroot


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\paprika\Image_94.jpg
sco

test image path is dataset_fruit_veg/test\pineapple\Image_83.jpg
score is 0.620139479637146, class id is 25, class name is pineapple


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\pineapple\Image_90.jpg
score is 0.10655967891216278, class id is 4, class name is cabbage


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\pomegranate\Image_10.jpg
score is 0.47969236969947815, class id is 26, class name is pomegranate


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\pomegranate\Image_11.jpg
score is 0.21734723448753357, class id is 22, class name is paprika


当前mode: infer
number of trainable parameters(M):11.19
Res

test image path is dataset_fruit_veg/test\raddish\Image_96.jpg
score is 0.29152652621269226, class id is 12, class name is garlic


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\soy beans\Image_29.jpg
score is 0.32663461565971375, class id is 29, class name is soy beans


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\soy beans\Image_4.jpg
score is 0.8000284433364868, class id is 29, class name is soy beans


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\soy beans\Image_41.jpg
score is 0.37919551134109497, class id is 30, class name is spinach


当前mode: infer
number of trainable parameters(M):11.19
Resume chec

test image path is dataset_fruit_veg/test\sweetpotato\Image_20.jpg
score is 0.5219218134880066, class id is 32, class name is sweetpotato


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\sweetpotato\Image_24.jpg
score is 0.23031564056873322, class id is 11, class name is eggplant


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\sweetpotato\Image_42.jpg
score is 0.3974957764148712, class id is 32, class name is sweetpotato


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\sweetpotato\Image_45.jpg
score is 0.19609618186950684, class id is 27, class name is potato


当前mode: infer
number of trainable parameters(M):11

test image path is dataset_fruit_veg/test\watermelon\Image_51.jpg
score is 0.6818152666091919, class id is 35, class name is watermelon


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\watermelon\Image_55.jpg
score is 0.35423508286476135, class id is 35, class name is watermelon


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\watermelon\Image_60.jpg
score is 0.4138772487640381, class id is 35, class name is watermelon


当前mode: infer
number of trainable parameters(M):11.19
Resume checkpoint output_dir_pretrained/checkpoint-19.pth
With optim & sched!
test image path is dataset_fruit_veg/test\watermelon\Image_91.jpg
score is 0.5299580693244934, class id is 35, class name is watermelon


In [20]:
a = torch.randint(0,10,(1,5))
print(a)
b = torch.argmax(a,dim=1)
c = torch.max(a,dim=1)
c

tensor([[4, 4, 2, 4, 6]])


torch.return_types.max(
values=tensor([6]),
indices=tensor([4]))

In [35]:
class_dict = {'apple': 0, 'banana': 1, 'beetroot': 2, 'bell pepper': 3, 'cabbage': 4, 'capsicum': 5, 'carrot': 6, 'cauliflower': 7, 'chilli pepper': 8, 'corn': 9, 'cucumber': 10, 'eggplant': 11, 'garlic': 12, 'ginger': 13, 'grapes': 14, 'jalepeno': 15, 'kiwi': 16, 'lemon': 17, 'lettuce': 18, 'mango': 19, 'onion': 20, 'orange': 21, 'paprika': 22, 'pear': 23, 'peas': 24, 'pineapple': 25, 'pomegranate': 26, 'potato': 27, 'raddish': 28, 'soy beans': 29, 'spinach': 30, 'sweetcorn': 31, 'sweetpotato': 32, 'tomato': 33, 'turnip': 34, 'watermelon': 35}
class_dict.keys()
class_idx = 1
list(class_dict.keys()) [list(class_dict.values()).index(class_idx)]

'banana'