In [2]:
import threading
import torch

In [2]:
def task():
    """Print a greeting along with the native (OS-level) id of the thread running it."""
    native_id = threading.get_native_id()  # the native id is handy when debugging which thread printed what
    print("hi from a thread", native_id)


# Thread(target=...) takes the function the new thread should run
t = threading.Thread(target=task)
t.start()  # from here on the program has 2 threads: the main thread and `t`
# there is no join() here, so this print and the one inside task() can interleave in the output
print("hi from the main thread, which is ", threading.get_native_id())

hi from a threadhi from the main thread, which is  186488
 186506


In [3]:
total = 0

def adder(count):
    """Add 0 + 1 + ... + (count - 1) into the global `total`.

    Only the worker thread started below writes `total`, and the main thread
    reads it after join(), so no lock is needed here.
    """
    global total
    for i in range(count):
        total += i

t = threading.Thread(target=adder, args=[1000000])  # run adder(count=1000000) on a new thread
t.start()
# Wait until the thread is done before reading `total`. Without join(), the main
# thread could print a partially accumulated value, depending on when the scheduler
# switches threads. This is a race condition: whether the output is correct depends
# on the scheduler, and the bug may show up only rarely, which makes it hard to reproduce.
t.join()
print(total)

499999500000


In [7]:
total = torch.tensor(0, dtype=torch.int32)

def count_up(count):
    """Add 1 to the shared tensor `total`, `count` times, without any locking.

    Two threads run this at the same time below. `total += 1` is a
    read-modify-write rather than one indivisible step, so increments can be
    lost and the final value can come out below 2 * count.
    """
    global total
    for _ in range(count):
        total += 1

t1 = threading.Thread(target=count_up, args=[1000000])
t2 = threading.Thread(target=count_up, args=[1000000])
t1.start()
t2.start()
# wait for both workers so that `total` is no longer changing when it is displayed
t1.join()
t2.join()

In [8]:
total # that's not right! 2 threads x 1,000,000 increments should give 2,000,000, but some increments were lost

tensor(1999560, dtype=torch.int32)

In [9]:
# Let's see why this error happens by looking at the bytecode the Python VM actually runs.
import dis
# Warm-up: even an empty program compiles to bytecode (load the constant None, then return it).
dis.dis("")

  1           0 LOAD_CONST               0 (None)
              2 RETURN_VALUE


In [10]:
dis.dis("total+=1")
# Just adding 1 to a variable is a 4-step process (up to and including STORE_NAME):
# LOAD_NAME reads the current value, LOAD_CONST pushes 1, INPLACE_ADD combines them,
# and STORE_NAME writes the result back to `total`.
# Suppose another thread updates `total` between this thread's LOAD_NAME and STORE_NAME.
# This thread then stores a result computed from the stale value, and the other update is lost.
# NOTE(review): in the counting cells `total` is a torch tensor. `+=` there is presumably an
# in-place add on the same tensor object, in which case the STORE step alone shouldn't drop
# updates. The lost counts would then come from the adds themselves overlapping inside torch. Confirm.

#see race conditions worksheet

  1           0 LOAD_NAME                0 (total)
              2 LOAD_CONST               0 (1)
              4 INPLACE_ADD
              6 STORE_NAME               0 (total)
              8 LOAD_CONST               1 (None)
             10 RETURN_VALUE


In [3]:
import time
total = torch.tensor(0, dtype=torch.int32)

def count_up(count):
    """Add 1 to the shared tensor `total`, `count` times, with no locking (timing baseline)."""
    global total
    for _ in range(count):
        total += 1

t1 = threading.Thread(target=count_up, args=[1000000])
t2 = threading.Thread(target=count_up, args=[1000000])
# perf_counter() is monotonic and high-resolution. time.time() follows the wall clock,
# which can jump if it is adjusted and would skew an elapsed-time measurement.
start = time.perf_counter()
t1.start()
t2.start()
t1.join()
t2.join()
end = time.perf_counter()

print("seconds", end-start)
total

seconds 10.628129243850708


tensor(1999749, dtype=torch.int32)

In [4]:
# add locking
import time

total = torch.tensor(0, dtype=torch.int32)
lock = threading.Lock() # protects the global variable "total"

def count_up(count):
    """Add 1 to the shared tensor `total`, `count` times, taking `lock` around each increment.

    Fine-grained locking gives the right result, but acquiring and releasing
    the lock on every iteration is expensive. In the recorded runs this took
    about 36 s, versus about 10.6 s without a lock.
    """
    global total
    for _ in range(count):
        # `with lock` acquires on entry and releases on exit even if the
        # increment raises, so the lock can never be left held
        with lock:
            total += 1

t1 = threading.Thread(target=count_up, args=[1000000])
t2 = threading.Thread(target=count_up, args=[1000000])
# perf_counter() is monotonic and high-resolution. time.time() follows the wall clock,
# which can jump if it is adjusted and would skew an elapsed-time measurement.
start = time.perf_counter()
t1.start()
t2.start()
t1.join()
t2.join()
end = time.perf_counter()

print("seconds", end-start)
total
#ran slower, but right result!
#let's try saving time next with the coarse grained lock

seconds 36.16609764099121


tensor(2000000, dtype=torch.int32)

In [None]:
# add locking (coarse grained - hold lock for a long time)
import time

total = torch.tensor(0, dtype=torch.int32)
lock = threading.Lock() # protects total

def count_up(count):
    """Add 1 to the shared tensor `total`, `count` times, holding `lock` for the whole loop.

    Coarse-grained locking takes the lock once per call instead of once per
    increment. A second caller waits until the first has finished its entire
    loop, so the threads' loops run one after the other rather than interleaved.
    """
    global total
    # `with lock` releases the lock even if the loop raises partway through
    with lock:
        for _ in range(count):
            total += 1

t1 = threading.Thread(target=count_up, args=[1000000])
t2 = threading.Thread(target=count_up, args=[1000000])
# perf_counter() is monotonic and high-resolution. time.time() follows the wall clock,
# which can jump if it is adjusted and would skew an elapsed-time measurement.
start = time.perf_counter()
t1.start()
t2.start()
t1.join()
t2.join()
end = time.perf_counter()

print("seconds", end-start)
total

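## Bad bank

A single lock guards every account balance. This first version of `transfer` acquires and releases that lock by hand. Watch what happens when a transfer raises an exception while it holds the lock.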

In [13]:
bank_accounts = {"x": 25, "y": 100, "z": 200} # in dollars
lock = threading.Lock() # protects bank_accounts

def transfer(src, dst, amount):
    """Move `amount` dollars from `src` to `dst` if `src` can cover it, then print the outcome.

    Deliberately buggy ("bad bank"): the lock is acquired and released by hand
    with no try/finally. If anything in between raises (e.g. a KeyError for an
    unknown account), lock.release() is skipped and the lock stays held.
    """
    lock.acquire()
    has_funds = bank_accounts[src] >= amount
    if has_funds:
        bank_accounts[src] -= amount
        bank_accounts[dst] += amount
    print("transferred" if has_funds else "denied")
    lock.release()

In [14]:
transfer("x", "y", 3) # this will work; so would -3, since 25 >= -3 passes the balance check (NOTE(review): a negative amount moves money from y to x without checking y's balance)

transferred


In [15]:
print(bank_accounts)
transfer("w", "y", -3) # there is no account "w", so bank_accounts["w"] inside transfer raises KeyError
print(bank_accounts) # never reached: the KeyError above ends the cell
# The KeyError is raised after lock.acquire() but before lock.release(), so the lock is never released.
# threading.Lock is not reentrant, so every later transfer() call will block forever in lock.acquire().

{'x': 22, 'y': 103, 'z': 200}


KeyError: 'w'

In [None]:
transfer("x", "y", 1) # with the lock still held from the failed "w" transfer, this blocks forever in lock.acquire()


In [None]:
print(bank_accounts)
transfer("x", "y", 3) # also hangs in lock.acquire() if the lock was left held by the failed "w" transfer
print(bank_accounts)

In [None]:
#fix the exception issue where the lock is never released
bank_accounts = {"x": 25, "y": 100, "z": 200} # in dollars
lock = threading.Lock() # protects bank_accounts

def transfer(src, dst, amount):
    """Move `amount` dollars from account `src` to account `dst` under `lock`, printing the outcome.

    Prints "transferred" on success. Prints "denied" and leaves all balances
    unchanged when `amount` is negative or `src` cannot cover it.

    Raises:
        KeyError: if `src` or `dst` is not an existing account. Both are checked
            before any balance changes, so a failed transfer never moves money.
    """
    # `with lock` calls lock.acquire() on entry and lock.release() on exit,
    # including when the body raises, so an exception can no longer leave the lock held.
    with lock:
        src_balance = bank_accounts[src]  # KeyError for an unknown src, as before
        if dst not in bank_accounts:
            # raise before debiting src; otherwise the money would simply vanish
            raise KeyError(dst)
        # a negative amount would pull money out of dst without checking dst's balance
        success = 0 <= amount <= src_balance
        if success:
            bank_accounts[src] -= amount
            bank_accounts[dst] += amount
        print("transferred" if success else "denied")

In [None]:
transfer("x", "y", 1) # normal transfer: moves $1 from x to y

In [None]:
transfer("w", "y", 1) # raises KeyError (no account "w"), but this time `with lock` releases the lock for us on the way out

In [None]:
transfer("x", "y", 1) # still works even though the previous call raised, because the lock was released