In [1]:
import argparse
import json
import numpy as np
import os
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import utils.misc as misc
from utils.misc import NativeScalerWithGradNormCount as NativeScaler
import swin_mae
from utils.engine_pretrain import train_one_epoch


In [2]:
class get_args_parser():
    batch_size = 96
    epochs = 400
    save_freq = 400
    checkpoint_encoder = ''
    checkpoint_decoder = ''
    data_path = ''
    mask_ratio = 0.75

    model = 'swin_mae'
    input_size = 224
    norm_pix_loss = False  # Use (per-patch) normalized pixels as targets for computing loss
    
    accum_iter = 1
    weight_decay = 0.05
    lr = 1e-3
    min_lr = 0
    warmup_epochs = 10

    output_dir = './output_dir/200epochs'
    log_dir = './output_dir/200epochs'
    device = 'cuda'
    seed = 42
    start_epoch = 0
    num_workers = 1
    pin_mem = True  # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.

args = get_args_parser

In [3]:
args.data_path = '/home/yuchen/Swin-MAE/land4sensor/image_pretrain'
args.batch_size = 48
args.epochs = 200
args.save_freq = 50
args.lr = 1e-2

# args.mask_ratio = 0.6

In [4]:
# Fixed random seeds
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)

# Set up training equipment
device = torch.device(args.device)
cudnn.benchmark = True

# Defining data augmentation
transform_train = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
])

# Set dataset
dataset_train = datasets.ImageFolder(args.data_path, transform=transform_train)
sampler_train = torch.utils.data.RandomSampler(dataset_train)
data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True
)


In [5]:
for batch in data_loader_train:
    break

print(batch[0].shape)
print(batch[1].shape)

torch.Size([48, 3, 224, 224])
torch.Size([48])


In [6]:
# Log output
if args.log_dir is not None:
    os.makedirs(args.log_dir, exist_ok=True)
    log_writer = SummaryWriter()
else: log_writer = None

# Set model
model = swin_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, mask_ratio=args.mask_ratio)
model.to(device)
model_without_ddp = model

# Set optimizer
param_groups = [p for p in model_without_ddp.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=5e-2, betas=(0.9, 0.95))  # 原来是5E-2
loss_scaler = NativeScaler()

# Create model
misc.load_model(args=args, model_without_ddp=model_without_ddp)

In [None]:
model_save_dir = f'/home/yuchen/Swin-MAE/output_dir/{args.epochs}epochs_01lr/'

if not os.path.exists(model_save_dir):
    os.makedirs(model_save_dir)

print(f"Start training for {args.epochs} epochs")
for epoch in range(args.start_epoch, args.epochs):
    train_stats = train_one_epoch(
        model, data_loader_train,
        optimizer, device, epoch, loss_scaler,
        log_writer=log_writer,
        args=args
    )
    
    if args.output_dir and ((epoch + 1) % args.save_freq == 0 or epoch + 1 == args.epochs):
        torch.save(model_without_ddp.state_dict(), f'{model_save_dir}swinmae{str(epoch+1)}.pth')
        
        # misc.save_model(
        #     args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
        #     loss_scaler=loss_scaler, epoch=epoch + 1)

    log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch, }

    if args.output_dir and misc.is_main_process():
        if log_writer is not None:
            log_writer.flush()
        with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
            f.write(json.dumps(log_stats) + "\n")

Start training for 200 epochs
log_dir: runs/Mar09_19-28-57_my-beloved-pc
Epoch: [0]  [ 0/48]  eta: 0:01:30  lr: 0.000000  loss: 1.3560 (1.3560)  time: 1.8891  data: 0.1041  max mem: 5324
Epoch: [0]  [10/48]  eta: 0:00:17  lr: 0.000208  loss: 0.6667 (0.7460)  time: 0.4698  data: 0.0095  max mem: 5539
Epoch: [0]  [20/48]  eta: 0:00:11  lr: 0.000417  loss: 0.2033 (0.4386)  time: 0.3278  data: 0.0000  max mem: 5539
Epoch: [0]  [30/48]  eta: 0:00:06  lr: 0.000625  loss: 0.0503 (0.3091)  time: 0.3276  data: 0.0000  max mem: 5539
Epoch: [0]  [40/48]  eta: 0:00:02  lr: 0.000833  loss: 0.0310 (0.2402)  time: 0.3275  data: 0.0000  max mem: 5539
Epoch: [0]  [47/48]  eta: 0:00:00  lr: 0.000979  loss: 0.0240 (0.2082)  time: 0.3276  data: 0.0000  max mem: 5539
Epoch: [0] Total time: 0:00:17 (0.3608 s / it)
Averaged stats: lr: 0.000979  loss: 0.0240 (0.2082)
log_dir: runs/Mar09_19-28-57_my-beloved-pc
Epoch: [1]  [ 0/48]  eta: 0:00:20  lr: 0.001000  loss: 0.0199 (0.0199)  time: 0.4263  data: 0.0965  m

In [4]:
# import os
# import shutil

# # Define the source and target directories
# source_dir = '/home/yuchen/Swin-MAE/land4sensor/resplit_data'
# target_dir = '/home/yuchen/Swin-MAE/land4sensor/image_pretrain/landslides'

# # Check if target directory exists, create if it doesn't
# if not os.path.exists(target_dir):
#     os.makedirs(target_dir)

# # Walk through the source directory
# for dirpath, dirnames, filenames in os.walk(source_dir):
#     for file in filenames:
#         # Check if the file is a JPG image
#         if file.lower().endswith('.jpg'):
#             # Construct full file paths
#             src_file_path = os.path.join(dirpath, file)
#             dst_file_path = os.path.join(target_dir, file)
            
#             # Copy the file to the target directory
#             shutil.copy2(src_file_path, dst_file_path)
#             print(f"Copied: {src_file_path} to {dst_file_path}")



In [9]:
import glob
file = glob.glob('/home/yuchen/Swin-MAE/land4sensor/image_pretrain/landslides/*.jpg')
print(len(file))

2324
