In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import importlib

In [7]:
import seesaw
import torch
import torch.nn as nn

In [8]:
import clip

In [10]:
from seesaw import *

In [11]:
import copy

In [12]:
def load_model(variant="ViT-B/32", device='cpu', jit=False):
    """Load a CLIP model and return it without its preprocessing transform.

    Args:
        variant: CLIP architecture name passed to ``clip.load``.
        device: device string passed to ``clip.load``.
        jit: value forwarded as ``clip.load``'s ``jit`` argument.

    Returns:
        The first element of the pair returned by ``clip.load``; the second
        element is discarded.
    """
    model, _ = clip.load(variant, device=device, jit=jit)
    return model

In [13]:
# Module-level CLIP model instance; later cells hand it to StringEncoder.
mod = load_model()

In [18]:
class StringEncoder(object):
    """Wraps a CLIP-style model to embed strings and to restore its initial weights.

    A deep copy of the model's state dict is taken at construction time so that
    ``reset`` can undo any later in-place changes to the model's parameters.
    """

    def __init__(self, model, device):
        self.device = device
        self.model = model.to(device)
        # Snapshot of the weights as they were when the encoder was created.
        self.original_weights = copy.deepcopy(model.state_dict())
        self.reset()

    def reset(self):
        """Load a fresh copy of the snapshot weights back into the model."""
        pristine = copy.deepcopy(self.original_weights)
        self.model.load_state_dict(pristine)

    def encode_string(self, string):
        """Return the L2-normalised text embedding of ``string`` as a numpy array.

        The model is switched to eval mode and no gradients are tracked.
        """
        net = self.model.eval()
        with torch.no_grad():
            tokens = clip.tokenize([string]).to(self.device)
            embedding = net.encode_text(tokens)
            unit = embedding / embedding.norm(dim=-1, keepdim=True)
            return unit.detach().cpu().numpy()


# class TextFeatures(object):
#     def __init__(self, model, target_string, device):
#         self.target_string= target_string
#         self.device = device
        
#         self.s2id = {}
#         self.sids = []
#         self.s2id[target_string] = 0
#         self.vecs = []
        
#     def add(self, strings):
def get_text_features(self, actual_strings, target_string):
    """Encode the unique strings among ``target_string`` and ``actual_strings``.

    Args:
        self: object exposing ``model.encode_text`` and ``device`` (the first
            parameter is named ``self`` although this is a free function).
        actual_strings: iterable of strings; it is materialised once, so any
            iterable (list, tuple, generator) is accepted.
        target_string: string that always receives id 0.

    Returns:
        (text_features, stringids, ustrings) where ``ustrings`` lists the unique
        strings in first-seen order with ``target_string`` first,
        ``text_features`` holds one L2-normalised row per unique string, and
        ``stringids`` is a long tensor mapping each entry of ``actual_strings``
        to its row in ``text_features``.
    """
    actual_strings = list(actual_strings)

    # Assign ids in first-seen order; the target string is pinned to id 0.
    s2id = {target_string: 0}
    for s in actual_strings:
        if s not in s2id:
            s2id[s] = len(s2id)

    ustrings = list(s2id)
    stringids = torch.tensor([s2id[s] for s in actual_strings], dtype=torch.long).to(self.device)
    tstrings = clip.tokenize(ustrings)
    text_features = self.model.encode_text(tstrings.to(self.device))
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features, stringids, ustrings
        
def forward(self, imagevecs, actual_strings, target_string):
    """Score every image vector against every unique string.

    ``imagevecs`` is used exactly as given (no conversion or normalisation is
    applied here) and multiplied with the transposed text features returned by
    ``get_text_features``.

    Returns:
        (scores, stringids, ustrings): the score matrix, the per-row string ids
        moved to ``self.device``, and the list of unique strings.
    """
    feats, ids, unique_strings = get_text_features(self, actual_strings, target_string)
    similarity = imagevecs @ feats.t()

    # One score row is expected per entry of actual_strings.
    assert similarity.shape[0] == ids.shape[0]
    return similarity, ids.to(self.device), unique_strings

def forward2(self, imagevecs, actual_strings, target_string):
    """Return per-image scores against the target string and against each image's own string.

    ``imagevecs`` is a numpy array; it is converted to a tensor of the text
    features' dtype, L2-normalised along the last axis and moved to
    ``self.device``.

    Returns:
        (search_score, confounder_score): dot product of each image vector with
        the target string's features (row 0), and dot product of each image
        vector with the features of its corresponding actual string.
    """
    feats, ids, _ = get_text_features(self, actual_strings, target_string)
    target_vec = feats[0].reshape(-1)
    per_image_text = feats[ids]

    imgs = torch.from_numpy(imagevecs).type(feats.dtype)
    imgs = (imgs / imgs.norm(dim=-1, keepdim=True)).to(self.device)

    return imgs @ target_vec, (imgs * per_image_text).sum(dim=1)
    
import torch.optim

class Updater(object):
    """Fine-tunes the tail of a text encoder from (image vector, string) feedback.

    Only ``se.model.ln_final`` and ``se.model.text_projection`` are handed to
    the AdamW optimiser; weight decay is disabled.

    Args:
        se: encoder wrapper exposing ``model`` (and used by ``forward``).
        lr: learning rate for both parameter groups.
        rounds: number of optimiser steps taken per ``update`` call.
        losstype: ``'hinge'`` (margin loss against every string) or ``'ce'``
            (cross-entropy over the per-string scores).
    """

    def __init__(self, se, lr, rounds=1, losstype='hinge'):
        self.se = se
        self.losstype = losstype
        self.opt = torch.optim.AdamW(
            [{'params': se.model.ln_final.parameters()},
             {'params': se.model.text_projection}],
            lr=lr, weight_decay=0.)
        self.rounds = rounds

    def update(self, imagevecs, actual_strings, target_string):
        """Take ``self.rounds`` optimiser steps on the given feedback batch.

        Args:
            imagevecs: image vectors, passed unchanged to ``forward``.
            actual_strings: the string associated with each image vector.
            target_string: the search string (always scored as column 0).

        Raises:
            ValueError: if ``self.losstype`` is neither ``'ce'`` nor ``'hinge'``.
        """
        # Local import so this class does not depend on a later notebook cell
        # having already bound ``F``.
        import torch.nn.functional as F

        se = self.se
        se.model.train()
        losstype = self.losstype
        opt = self.opt
        margin = .3

        def opt_closure():
            opt.zero_grad()
            if losstype not in ('ce', 'hinge'):
                raise ValueError("unknown losstype: %r" % (losstype,))

            scores, stringids, _ = forward(se, imagevecs, actual_strings, target_string)
            if losstype == 'ce':
                # Previously this branch never produced a loss. Cross-entropy of
                # the raw scores against the correct string id is used here.
                # NOTE(review): scores are not temperature-scaled — confirm this
                # matches the intended objective.
                losses = F.cross_entropy(scores, stringids, reduction='none')
            else:
                rows = torch.arange(scores.shape[0], device=scores.device)
                actual_score = scores[rows, stringids].reshape(-1, 1)
                # Penalise every string whose score is within ``margin`` of (or
                # above) the correct string's score.
                losses = F.relu(-(actual_score - scores - margin))

            loss = losses.mean()
            loss.backward()
            # Optimizer.step(closure) expects the closure to return the loss.
            return loss

        for _ in range(self.rounds):
            opt.step(opt_closure)



def closure(search_query, max_n, firsts, show_display=False, batch_size=10):
    """Run one simulated feedback search for an ObjectNet query and record the first hit.

    Repeatedly encodes the query, fetches a batch of results, collects the
    feedback strings for that batch and fine-tunes the encoder on everything
    seen so far, until a batch contains a positive or more than ``max_n``
    feedback strings have accumulated.

    Args:
        search_query: key into ``search_terms['objectnet']`` and into
            ``ev.query_ground_truth``.
        max_n: stop once more than this many feedback strings were collected.
        firsts: dict updated in place; ``firsts[search_query]`` becomes the
            1-based position of the first positive, or ``np.inf`` if none.
        show_display: when true, display images, ground truth and score tables.
        batch_size: number of results requested per round.

    NOTE(review): relies on notebook globals ``search_terms``, ``hdb``, ``ev``
    and ``mod``; ``hdb`` and ``ev`` are not defined at module level in the
    visible notebook — confirm before use.
    """
    sq = search_terms['objectnet'][search_query]
    # StringEncoder requires a model and a device; build it the same way the
    # LVIS experiment cell does.
    se = StringEncoder(model=mod, device='cpu')
    up = Updater(se, lr=.0001, rounds=1)
    bs = batch_size
    bfq = BoxFeedbackQuery(hdb, batch_size=bs, auto_fill_df=None)
    dbidxs = []
    accstrs = []
    gts = []
    try:
        while True:
            tvec = se.encode_string(sq)
            # NOTE(review): lvisloop unpacks two values from query_stateful;
            # confirm which return shape applies here.
            idxbatch = bfq.query_stateful(mode='dot', vector=tvec, batch_size=bs)
            dbidxs.append(idxbatch)
            gtvals = ev.query_ground_truth[search_query][idxbatch].values
            gts.append(gtvals)
            if show_display:
                display(hdb.raw.show_images(idxbatch))
                display(gtvals)
            actual_strings = get_feedback(idxbatch)
            accstrs.extend(actual_strings)

            if show_display:
                display(actual_strings)
            if gtvals.sum() > 0 or len(accstrs) > max_n:
                break

            # Train on everything seen so far. ``forward`` multiplies the image
            # vectors with torch text features, so convert them to a float
            # tensor on the encoder's device first (as lvisloop does).
            vcs = torch.as_tensor(ev.embedded_dataset[np.concatenate(dbidxs)]).float().to(se.device)
            astrs = accstrs

            if show_display:
                show_scores(se, vcs, astrs, target_string=sq)

            up.update(vcs, actual_strings=astrs, target_string=sq)

            if show_display:
                show_scores(se, vcs, astrs, target_string=sq)
    finally:
        # The encoder trains the shared ``mod`` in place; restore its weights so
        # one query's fine-tuning does not leak into the next.
        se.reset()

    frsts = np.where(np.concatenate(gts).reshape(-1))[0]
    if frsts.shape[0] == 0:
        firsts[search_query] = np.inf
    else:
        firsts[search_query] = frsts[0] + 1

In [19]:
import torch.nn.functional as F

In [21]:
from seesaw import *

In [23]:
import ray
# Wrap the named ray actor 'clip' in a ModelService; `xclip` is later passed to
# the dataset loaders in load_ds.
# NOTE(review): the saved output of this cell is a ValueError (actor lookup
# failed) — presumably a ray session with a running 'clip' actor must exist
# first; confirm whether the ray.init call below needs to be re-enabled.
#ray.init('auto')
xclip = ModelService(ray.get_actor('clip'))

ValueError: Failed to look up actor with name 'clip'. This could because 1. You are trying to look up a named actor you didn't create. 2. The named actor died. 3. The actor hasn't been created because named actor creation is asynchronous. 4. You did not use a namespace matching the namespace of the actor.

In [24]:
from vsms import *
# Benchmark datasets available to load_ds: dataset name -> dict with a `loader`
# callable and optional `idxs` used to subsample (None keeps everything).
benchparams = dict(
    #objectnet=dict(loader=objectnet_cropped, idxs=np.load('./data/object_random_idx.npy')[:10000]),
    lvis=dict(loader=lvis_full, idxs=None),

)

def load_ds(evs, dsnames):
    """Load the requested benchmark datasets into ``evs``.

    Args:
        evs: dict updated in place; ``evs[name]`` is set to the result of
            ``extract_subset`` for each requested dataset.
        dsnames: a dataset name or an iterable of dataset names. A bare string
            is treated as a single name (not as a sequence of characters, and
            not as a substring pattern).

    Only names that are keys of ``benchparams`` are loaded; others are ignored.
    """
    if isinstance(dsnames, str):
        dsnames = [dsnames]
    wanted = set(dsnames)

    for k, v in tqdm(benchparams.items(), total=len(benchparams)):
        if k not in wanted:
            continue
        ev0 = v['loader'](xclip)
        idxs = v['idxs']
        # Subsample indices are sorted before extraction; None keeps everything.
        idxs = np.sort(idxs) if idxs is not None else None
        evs[k] = extract_subset(ev0, idxsample=idxs)

In [25]:
# Load the LVIS benchmark; dataset names are passed as a list so that name
# matching is by exact key rather than by substring of a string.
evs = {}
load_ds(evs, ['lvis'])

  0%|          | 0/1 [00:00<?, ?it/s]

In [26]:
# Loaded LVIS dataset object; passed to lvisloop by the experiment cell.
ev0 = evs['lvis']

In [27]:
from vsms import category2query

In [28]:
#del se

In [29]:
#vecs = ev0.embedded_dataset
def lvisloop(ev0, se, category, firsts, max_n, batch_size, tqdm_disabled, feedback=True):
    """Simulate a feedback-driven search for one LVIS category and record the first hit.

    Each round encodes the category's query string, fetches ``batch_size``
    results and stops as soon as a batch contains a positive (or the result
    budget is exceeded). Otherwise, when ``feedback`` is true, the labelled
    box vectors of all images seen so far are used to fine-tune the encoder
    before the next round.

    Args:
        ev0: full dataset object (project type).
        se: StringEncoder whose model is fine-tuned in place; the caller is
            responsible for resetting it afterwards.
        category: LVIS category name.
        firsts: dict updated in place; ``firsts[category]`` becomes the 1-based
            position of the first positive, or ``np.inf`` if none was found.
        max_n: result budget; at most ``int(max_n) // batch_size + 1`` rounds run.
        batch_size: number of results per round.
        tqdm_disabled: passed to ``tqdm(disable=...)``.
        feedback: when false, the encoder is never updated.
    """
    n_batches = (int(max_n) // batch_size) + 1
    sq = category2query('lvis', category)
    up = Updater(se, lr=.0001, rounds=1)

    ev, class_idxs = get_class_ev(ev0, category, boxes=True)
    evfull = extract_subset(ev0, categories=None, idxsample=class_idxs, boxes=True)

    hdb = AugmentedDB(raw_dataset=ev.image_dataset, embedding=ev.embedding,
        embedded_dataset=ev.fine_grained_embedding, vector_meta=ev.fine_grained_meta,
        index_path=None)

    bfq = BoxFeedbackQuery(hdb, batch_size=batch_size, auto_fill_df=None)
    ground_truth = ev.query_ground_truth[category]

    accidxs = []
    accstrs = []
    accvecids = []
    gts = []

    def feedback_step(tvec):
        # Kept as a nested function so the training tensors go out of scope
        # (and can be freed) as soon as the update finishes.
        avecs = evfull.fine_grained_embedding[np.concatenate(accvecids)]
        scs = avecs @ tvec.reshape(-1)
        # Train only on the (up to) 1000 vectors currently scoring highest for
        # the query.
        topk = np.argsort(-scs)[:1000]
        trvecs = torch.from_numpy(avecs[topk]).float().to(se.device)
        trstrs = [accstrs[j] for j in topk]
        up.update(trvecs, trstrs, sq)
        del trvecs
        torch.cuda.empty_cache()

    for _ in tqdm(range(n_batches), leave=False, disable=tqdm_disabled):
        tvec = se.encode_string(sq)
        idxbatch, _other = bfq.query_stateful(mode='dot', vector=tvec, batch_size=batch_size)
        accidxs.append(idxbatch)
        gt = ground_truth.iloc[idxbatch].values
        gts.append(gt)
        if gt.sum() > 0 or len(gts)*batch_size > max_n:
            break

        if feedback:
            vecids, astrs = get_box_labels(evfull, idxbatch)
            accvecids.append(vecids)
            accstrs.extend(astrs)
            # Nothing labelled yet means there is nothing to train on.
            if accstrs:
                feedback_step(tvec)

    frsts = np.where(np.concatenate(gts).reshape(-1))[0]
    if frsts.shape[0] == 0:
        firsts[category] = np.inf
    else:
        firsts[category] = frsts[0] + 1

In [30]:
def labelvecs(boxes, meta, iou_cutoff=.001):
    """Assign each box to the vector it overlaps most, keeping overlaps above a cutoff.

    Args:
        boxes: DataFrame of boxes with a ``category`` column.
        meta: DataFrame of vector metadata; its index values identify vectors.
        iou_cutoff: a box is kept only if its best IoU exceeds this value.

    Returns:
        (absvecs, cats): for every kept box, the index value of its best
        matching row of ``meta`` and the box's category.
    """
    overlap = box_iou(boxes, meta)

    best_iou = np.max(overlap, axis=1)
    best_vec = np.argmax(overlap, axis=1)
    keep = best_iou > iou_cutoff

    kept_rows = np.arange(best_iou.shape[0])[keep]
    cats = list(boxes.iloc[kept_rows].category.values)
    absvecs = meta.index.values[best_vec[keep]]
    return absvecs, cats

In [31]:
def get_box_labels(evfull, allidxs):
    """Collect labelled vector ids and their query strings for the given images.

    For every image id in ``allidxs`` that has boxes, each box is matched to a
    vector via ``labelvecs`` and its category is turned into a query string
    with ``category2query('lvis', ...)``.

    Args:
        evfull: dataset object exposing ``box_data`` and ``fine_grained_meta``
            DataFrames, both with a ``dbidx`` column.
        allidxs: image ids (``dbidx`` values) to label.

    Returns:
        (avecids, astrs): array of vector ids and the parallel list of query
        strings. When none of the images has any box, an empty integer array
        and an empty list are returned instead of raising.
    """
    relboxes = evfull.box_data[evfull.box_data.dbidx.isin(allidxs)]

    vec_meta = evfull.fine_grained_meta[evfull.fine_grained_meta.dbidx.isin(allidxs)]
    vec_meta = vec_meta.drop_duplicates()
    vec_meta = vec_meta.assign(**get_boxes(vec_meta))

    vecposns = []
    astrs = []
    for (idx, boxes) in relboxes.sort_values('dbidx').groupby('dbidx'):
        meta = vec_meta[vec_meta.dbidx == idx]
        absvecs, cats = labelvecs(boxes, meta)
        vecposns.append(absvecs)
        astrs.extend(category2query('lvis', c) for c in cats)

    if not vecposns:
        # np.concatenate rejects an empty list; return an empty index array
        # that is still usable for fancy indexing.
        return np.empty(0, dtype=np.int64), astrs

    avecids = np.concatenate(vecposns)
    return avecids, astrs

In [32]:
### bfq currently returns image ids.

### What I'd ideally do to simulate a user giving localized feedback:
### the user gives feedback for all shown images (it can also be about specific regions only).
# 0. Assume we highlight regions to the user. The user can give negative feedback with a comment
##    indicating the type of mistake.
# 1. We need access to the vec id list and box metadata (ideally including previous vecs).
###   This creates a mapping of vec ids to texts.
# 2. We match ground-truth labels for other classes to the top vecs in the shown images.

In [33]:
def show_scores(se, vecs, actual_strings, target_string):
    """Display a table of scores with one column per unique string.

    The encoder's model is put into eval mode and scored without gradients;
    the largest value in each row is highlighted.
    """
    with torch.no_grad():
        se.model.eval()
        result = forward(se, vecs, actual_strings, target_string=target_string)
    score_matrix, column_names = result[0], result[2]

    # Transpose so that each unique string maps to its column of scores.
    per_string = dict(zip(column_names, score_matrix.cpu().numpy().transpose()))
    table = pd.DataFrame(per_string)
    display(table.style.highlight_max(axis=1))

def get_feedback(idxbatch):
    """Return the search strings of the ground-truth labels set for the given rows.

    Looks up the rows ``idxbatch`` of the global ``ev.query_ground_truth``,
    finds the columns with truthy entries, and maps each column name through
    ``search_terms['objectnet']``.
    """
    shown = ev.query_ground_truth.iloc[idxbatch]
    column_positions = np.where(shown)[1]
    label_names = ev.query_ground_truth.columns[column_positions]
    lookup = search_terms['objectnet']
    return [lookup[name] for name in label_names.values]

In [34]:
# Per-category table read from disk; later cells use its `category` and
# `nfirst_x` columns (nfirst_x is reported below as the "no_feedback" result).
curr_firsts = pd.read_parquet('./data/cats_lvis_ordered.parquet')

In [35]:
# Order categories from largest to smallest nfirst_x.
curr_firsts = curr_firsts.sort_values('nfirst_x', ascending=False)

In [36]:
#del se
#torch.cuda.empty_cache()

In [39]:
# Run the feedback loop for every category whose baseline first hit lies beyond
# the first batch (and is finite), resetting the encoder between categories.
firsts = {}
batch_size = 10
se = StringEncoder(model=mod, device='cpu')
cf = curr_firsts[(curr_firsts.nfirst_x > batch_size) & (curr_firsts.nfirst_x < np.inf)]
for x in tqdm(cf.itertuples(), total=cf.shape[0]):
    # Budget: five times the baseline position, capped at 200 results.
    lvisloop(ev0, se, x.category, firsts=firsts, max_n=min(x.nfirst_x*5, 200),
             batch_size=batch_size, tqdm_disabled=False, feedback=True)
    se.reset()
    torch.cuda.empty_cache()
    print(firsts[x.category], x.nfirst_x)
    # The former `if x.nfirst_x <= batch_size: break` was unreachable: cf only
    # contains rows with nfirst_x > batch_size.

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

90 97.0


  0%|          | 0/21 [00:00<?, ?it/s]

93 97.0


  0%|          | 0/21 [00:00<?, ?it/s]

69 90.0


  0%|          | 0/21 [00:00<?, ?it/s]

119 85.0


  0%|          | 0/21 [00:00<?, ?it/s]

194 80.0


  0%|          | 0/21 [00:00<?, ?it/s]

57 70.0


  0%|          | 0/21 [00:00<?, ?it/s]

56 68.0


  0%|          | 0/21 [00:00<?, ?it/s]

58 67.0


  0%|          | 0/21 [00:00<?, ?it/s]

127 67.0


  0%|          | 0/21 [00:00<?, ?it/s]

41 61.0


  0%|          | 0/21 [00:00<?, ?it/s]

119 61.0


  0%|          | 0/21 [00:00<?, ?it/s]

44 52.0


  0%|          | 0/21 [00:00<?, ?it/s]

45 50.0


  0%|          | 0/21 [00:00<?, ?it/s]

42 49.0


  0%|          | 0/21 [00:00<?, ?it/s]

44 47.0


  0%|          | 0/21 [00:00<?, ?it/s]

44 46.0


  0%|          | 0/21 [00:00<?, ?it/s]

40 43.0


  0%|          | 0/21 [00:00<?, ?it/s]

39 40.0


  0%|          | 0/21 [00:00<?, ?it/s]

39 40.0


  0%|          | 0/21 [00:00<?, ?it/s]

40 40.0


  0%|          | 0/20 [00:00<?, ?it/s]

33 39.0


  0%|          | 0/20 [00:00<?, ?it/s]

40 38.0


  0%|          | 0/19 [00:00<?, ?it/s]

93 37.0


  0%|          | 0/19 [00:00<?, ?it/s]

38 37.0


  0%|          | 0/18 [00:00<?, ?it/s]

31 35.0


  0%|          | 0/17 [00:00<?, ?it/s]

30 33.0


  0%|          | 0/15 [00:00<?, ?it/s]

26 29.0


  0%|          | 0/15 [00:00<?, ?it/s]

24 28.0


  0%|          | 0/15 [00:00<?, ?it/s]

25 28.0


  0%|          | 0/14 [00:00<?, ?it/s]

25 26.0


  0%|          | 0/14 [00:00<?, ?it/s]

30 26.0


  0%|          | 0/13 [00:00<?, ?it/s]

24 24.0


  0%|          | 0/12 [00:00<?, ?it/s]

24 23.0


  0%|          | 0/12 [00:00<?, ?it/s]

19 23.0


  0%|          | 0/12 [00:00<?, ?it/s]

22 22.0


  0%|          | 0/12 [00:00<?, ?it/s]

22 22.0


  0%|          | 0/11 [00:00<?, ?it/s]

21 21.0


  0%|          | 0/11 [00:00<?, ?it/s]

20 20.0


  0%|          | 0/10 [00:00<?, ?it/s]

17 19.0


  0%|          | 0/10 [00:00<?, ?it/s]

19 19.0


  0%|          | 0/10 [00:00<?, ?it/s]

18 19.0


  0%|          | 0/10 [00:00<?, ?it/s]

18 18.0


  0%|          | 0/10 [00:00<?, ?it/s]

17 18.0


  0%|          | 0/9 [00:00<?, ?it/s]

15 17.0


  0%|          | 0/9 [00:00<?, ?it/s]

17 17.0


  0%|          | 0/9 [00:00<?, ?it/s]

19 17.0


  0%|          | 0/9 [00:00<?, ?it/s]

18 17.0


  0%|          | 0/9 [00:00<?, ?it/s]

13 17.0


  0%|          | 0/9 [00:00<?, ?it/s]

17 16.0


  0%|          | 0/9 [00:00<?, ?it/s]

16 16.0


  0%|          | 0/8 [00:00<?, ?it/s]

16 15.0


  0%|          | 0/8 [00:00<?, ?it/s]

15 15.0


  0%|          | 0/8 [00:00<?, ?it/s]

15 15.0


  0%|          | 0/8 [00:00<?, ?it/s]

13 14.0


  0%|          | 0/7 [00:00<?, ?it/s]

14 13.0


  0%|          | 0/7 [00:00<?, ?it/s]

13 13.0


  0%|          | 0/7 [00:00<?, ?it/s]

4 12.0


  0%|          | 0/6 [00:00<?, ?it/s]

11 11.0


  0%|          | 0/6 [00:00<?, ?it/s]

11 11.0


  0%|          | 0/6 [00:00<?, ?it/s]

11 11.0


In [None]:
# %debug  # leftover post-mortem debugging magic; disabled so running the notebook never blocks on a debugger prompt

In [47]:
# Side-by-side comparison per category: first-hit position with feedback
# (from `firsts`) versus the baseline `nfirst_x` column of `cf`.
rdf = pd.concat(
    [
        pd.Series(firsts).rename('feedback'),
        cf[['category', 'nfirst_x']]
        .set_index('category')['nfirst_x']
        .rename('no_feedback'),
    ],
    axis=1,
)

In [48]:
# Show the comparison table (index: category; columns: feedback, no_feedback).
rdf

Unnamed: 0,feedback,no_feedback
masher,90,97.0
car battery,93,97.0
hamper,69,90.0
skullcap,119,85.0
chap,194,80.0
die,57,70.0
sawhorse,56,68.0
boom microphone,58,67.0
breechcloth,127,67.0
domestic ass,41,61.0


In [49]:
# Fraction of categories where feedback reached the first hit earlier than,
# at the same position as, and later than the no-feedback baseline.
tuple(
    outcome.mean()
    for outcome in (
        rdf.feedback < rdf.no_feedback,
        rdf.feedback == rdf.no_feedback,
        rdf.feedback > rdf.no_feedback,
    )
)

(0.5, 0.26666666666666666, 0.23333333333333334)

In [50]:
#rdf.to_parquet('./data/lvis_nfirst_verbal.parquet')