In [7]:

import torch
from torch import nn
torch.cuda.empty_cache()
from torch.nn.functional import interpolate

from hr_dv2 import HighResDV2
import hr_dv2.transform as tr

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from time import time_ns

# Pin the NumPy and PyTorch RNGs to the same fixed seed so that repeated
# runs of the notebook start from identical random state.
np.random.seed(0)
torch.manual_seed(0)

# NOTE(review): `use_norm` is not read by any cell shown in this part of the
# notebook — confirm it is still needed.
use_norm = True

In [8]:
def load_img(path: str, l: int) -> tuple[torch.Tensor, np.ndarray]:
    """Load the image at `path` as a model-ready tensor plus a NumPy array.

    The input transform is built with `tr.get_input_transform(l, l)` and the
    file is read through `tr.load_image(path, transform)`.

    Args:
        path: Location of the image file.
        l: Size argument forwarded twice to `tr.get_input_transform`.
            NOTE(review): annotated as `int`, but the notebook calls this
            function with a `(350, 350)` tuple — confirm which forms
            `get_input_transform` accepts and tighten the annotation.

    Returns:
        `(tensor, array)`: the tensor returned by `tr.load_image`, and the
        image object it returns alongside, converted with `np.array`.
    """
    # `tr.load_image` is the only reader of the file here. A separate
    # `Image.open(path)` used only to look at the dimensions would leave an
    # extra unclosed file handle behind, so none is made.
    transform = tr.get_input_transform(l, l)
    tensor, img = tr.load_image(path, transform)
    return tensor, np.array(img)

In [9]:
def measure_mem_time(inp: torch.Tensor, model: nn.Module, seq: bool = False) -> tuple[float, float]:
    """Run one forward pass and report its peak CUDA memory growth and duration.

    Args:
        inp: Input tensor. When `model` is a `HighResDV2`, dimension 0 is
            squeezed before the call.
        model: Module to benchmark.
        seq: If True call `model.forward_sequential(inp)`, otherwise
            `model.forward(inp)`.

    Returns:
        `(mem_mb, seconds)`: the increase in
        `torch.cuda.max_memory_allocated()` across the pass, in MiB
        (bytes / 1024**2), and the elapsed time in seconds.
    """
    # Monotonic, high-resolution clock meant for measuring short durations;
    # unlike the wall clock it cannot jump if the system time is adjusted.
    from time import perf_counter_ns

    bytes_per_mb = 1024 ** 2
    ns_per_s = 1e9

    # isinstance (rather than an exact type comparison) also covers subclasses.
    if isinstance(model, HighResDV2):
        inp = inp.squeeze(0)

    # Restart peak tracking so the reading taken below is the baseline for
    # this pass only, and drain queued GPU work so earlier kernels are not
    # billed to this measurement.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    start_m = torch.cuda.max_memory_allocated()
    start_t = perf_counter_ns()

    if seq:
        model.forward_sequential(inp)
    else:
        model.forward(inp)

    # Wait for the GPU to finish before stopping the clock and reading the
    # peak, so both numbers describe the completed pass.
    torch.cuda.synchronize()
    end_t = perf_counter_ns()
    end_m = torch.cuda.max_memory_allocated()

    return (end_m - start_m) / bytes_per_mb, (end_t - start_t) / ns_per_s

In [10]:
# Build the first model under test and configure it for half-precision
# inference on the GPU.
# NOTE(review): the meaning of the positional `4` is not visible here —
# confirm against the HighResDV2 constructor.
net = HighResDV2("dinov2_vits14_reg", 4, dtype=torch.float16)  # backbone names noted by the author: dino_vits8, dinov2_vits14_reg
net.interpolation_mode = 'nearest-exact'
net.eval()
net.cuda()
net.half()
# A bare `None` as the final expression keeps the cell from echoing a value.
None

Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dinov2_main


In [11]:
# Load the benchmark image, then in a single chain move the tensor to the
# GPU, add a leading dimension at index 0 and cast it to float16.
img_tensor, img_arr = load_img('../fig_data/1.jpg', (350, 350))
img_tensor = img_tensor.cuda().unsqueeze(0).half()

In [12]:
# Five untimed forward passes ahead of the measurements in the later cells
# (presumably a GPU warm-up so one-off start-up cost is not timed — confirm).
# The leading dimension added when the image was loaded is squeezed off again
# before each call.
for i in range(5):
    net.forward(img_tensor.squeeze(0))


In [13]:

# Transform sets to benchmark. Each entry is a pair that is later unpacked as
# (fwd, inv) and handed to `set_transforms`.
flip = tr.get_flip_transforms()
no_trs = ([], [])
moore_1 = tr.get_shift_transforms([1], 'Moore')
neumann_1 = tr.get_shift_transforms([1], 'Neumann')
moore_2 = tr.get_shift_transforms([1, 2], 'Moore')
moore_4 = tr.get_shift_transforms([1, 2, 3, 4], 'Moore')

moore_2_flip = tr.combine_transforms(moore_2[0], flip[0], moore_2[1], flip[1])
moore_4_flip = tr.combine_transforms(moore_4[0], flip[0], moore_4[1], flip[1])


names = ['no_trs', 'moore_1', 'neumann_1' , 'moore_2', 'moore_2_flip', 'moore_4', 'moore_4_flip']
fwd_inv_transforms: list[tuple[tr.PartialTrs, tr.PartialTrs]] = [no_trs, moore_1, neumann_1, moore_2, moore_2_flip, moore_4, moore_4_flip]

# The labels and the transform sets are parallel lists; fail loudly if they
# ever drift apart instead of silently mislabelling rows.
assert len(names) == len(fwd_inv_transforms), "names and fwd_inv_transforms must line up"

# Measure every transform set twice: first with the batched forward pass,
# then with the sequential one.
for is_seq in (False, True):
    prefix = "Sequential" if is_seq else "Batch"
    print(f"{prefix.upper()}\n")
    for name, (fwd, inv) in zip(names, fwd_inv_transforms):
        # Even with no transforms one forward pass is still measured, so
        # report at least 1.
        # NOTE(review): the empty `no_trs` list appears to be grown in place
        # somewhere inside the model (a single recorded run printed N_t=0 in
        # the batch pass and N_t=1 in the sequential pass), so `len(fwd)` on
        # its own is not a stable label — confirm in HighResDV2.
        n_t = max(len(fwd), 1)
        net.set_transforms(fwd, inv)
        mem, time = measure_mem_time(img_tensor, net, is_seq)
        print(f"\t{name}, N_t={n_t}: {mem:.2f} MB, {time:.2f} s")


BATCH

	no_trs, N_t=0: 281.29 MB, 0.03 s
	moore_1, N_t=9: 551.13 MB, 0.21 s
	neumann_1, N_t=5: 505.62 MB, 0.12 s
	moore_2, N_t=17: 1002.24 MB, 0.42 s
	moore_2_flip, N_t=68: 4008.97 MB, 1.54 s
	moore_4, N_t=33: 1945.37 MB, 0.76 s
	moore_4_flip, N_t=132: 7781.51 MB, 3.03 s
SEQUENTIAL

	no_trs, N_t=1: 281.29 MB, 0.03 s
	moore_1, N_t=9: 466.34 MB, 0.25 s
	neumann_1, N_t=5: 463.54 MB, 0.14 s
	moore_2, N_t=17: 472.03 MB, 0.48 s
	moore_2_flip, N_t=68: 508.03 MB, 1.94 s
	moore_4, N_t=33: 483.17 MB, 0.94 s
	moore_4_flip, N_t=132: 552.56 MB, 3.78 s


In [21]:
# Second model under test: same configuration steps as `net`, but built from
# the "dino_vits8" backbone.
dino_net = HighResDV2("dino_vits8", 4, dtype=torch.float16)
dino_net.interpolation_mode = 'nearest-exact'
dino_net.eval()
dino_net.cuda()
dino_net.half()

# Pick the transform set by name. Relying on whatever `fwd`/`inv` an earlier
# loop left behind makes the result depend on cell execution order; the
# largest set, `moore_4_flip`, is what that loop ends on.
fwd, inv = moore_4_flip
dino_net.set_transforms(fwd, inv)
mem, time = measure_mem_time(img_tensor, dino_net, False)
print(mem, time)

Using cache found in /home/ronan/.cache/torch/hub/facebookresearch_dino_main


7959.06201171875 3.02226084


In [23]:
# The labels and the transform sets are parallel lists; fail loudly if they
# ever drift apart instead of silently mislabelling rows.
assert len(names) == len(fwd_inv_transforms), "names and fwd_inv_transforms must line up"

# Same sweep as for `net`, run on `dino_net`: every transform set is measured
# with the batched forward pass first and the sequential one second.
for is_seq in (False, True):
    prefix = "Sequential" if is_seq else "Batch"
    print(f"{prefix.upper()}\n")
    for name, (fwd, inv) in zip(names, fwd_inv_transforms):
        # Even with no transforms one forward pass is still measured, so
        # report at least 1 regardless of whether the empty list has already
        # been modified by an earlier run.
        n_t = max(len(fwd), 1)
        dino_net.set_transforms(fwd, inv)
        mem, time = measure_mem_time(img_tensor, dino_net, is_seq)
        print(f"\t{name}, N_t={n_t}: {mem:.2f} MB, {time:.2f} s")

BATCH

	no_trs, N_t=1: 281.95 MB, 0.03 s
	moore_1, N_t=9: 552.99 MB, 0.22 s
	neumann_1, N_t=5: 506.85 MB, 0.13 s
	moore_2, N_t=17: 1025.12 MB, 0.40 s
	moore_2_flip, N_t=68: 4100.46 MB, 1.55 s
	moore_4, N_t=33: 1989.77 MB, 0.77 s
	moore_4_flip, N_t=132: 7959.06 MB, 3.04 s
SEQUENTIAL

	no_trs, N_t=1: 281.95 MB, 0.03 s
	moore_1, N_t=9: 466.92 MB, 0.26 s
	neumann_1, N_t=5: 464.20 MB, 0.14 s
	moore_2, N_t=17: 472.69 MB, 0.49 s
	moore_2_flip, N_t=68: 508.69 MB, 1.98 s
	moore_4, N_t=33: 483.83 MB, 0.96 s
	moore_4_flip, N_t=132: 552.81 MB, 3.85 s
